diff --git a/api.go b/api.go index d391726..b1d5fde 100644 --- a/api.go +++ b/api.go @@ -9,6 +9,7 @@ import ( "errors" "iter" "net/http" + "reflect" "slices" "time" @@ -24,23 +25,57 @@ const defaultAddr = ":8080" const shutdownTimeout = 10 * time.Second // Engine is the central API server managing route groups and middleware. +// +// Example: +// +// engine, err := api.New(api.WithAddr(":8081")) +// if err != nil { +// panic(err) +// } +// _ = engine.Handler() type Engine struct { - addr string - groups []RouteGroup - middlewares []gin.HandlerFunc - wsHandler http.Handler - sseBroker *SSEBroker - swaggerEnabled bool - swaggerTitle string - swaggerDesc string - swaggerVersion string - pprofEnabled bool - expvarEnabled bool - graphql *graphqlConfig + addr string + groups []RouteGroup + middlewares []gin.HandlerFunc + cacheTTL time.Duration + cacheMaxEntries int + cacheMaxBytes int + wsHandler http.Handler + wsPath string + sseBroker *SSEBroker + swaggerEnabled bool + swaggerTitle string + swaggerSummary string + swaggerDesc string + swaggerVersion string + swaggerPath string + swaggerTermsOfService string + swaggerServers []string + swaggerContactName string + swaggerContactURL string + swaggerContactEmail string + swaggerLicenseName string + swaggerLicenseURL string + swaggerSecuritySchemes map[string]any + swaggerExternalDocsDescription string + swaggerExternalDocsURL string + authentikConfig AuthentikConfig + pprofEnabled bool + expvarEnabled bool + ssePath string + graphql *graphqlConfig + i18nConfig I18nConfig } // New creates an Engine with the given options. // The default listen address is ":8080". +// +// Example: +// +// engine, err := api.New(api.WithAddr(":8081"), api.WithResponseMeta()) +// if err != nil { +// panic(err) +// } func New(opts ...Option) (*Engine, error) { e := &Engine{ addr: defaultAddr, @@ -52,27 +87,54 @@ func New(opts ...Option) (*Engine, error) { } // Addr returns the configured listen address. +// +// Example: +// +// engine, _ := api.New(api.WithAddr(":9090")) +// addr := engine.Addr() func (e *Engine) Addr() string { return e.addr } -// Groups returns all registered route groups. +// Groups returns a copy of all registered route groups. +// +// Example: +// +// groups := engine.Groups() func (e *Engine) Groups() []RouteGroup { - return e.groups + return slices.Clone(e.groups) } // GroupsIter returns an iterator over all registered route groups. +// +// Example: +// +// for group := range engine.GroupsIter() { +// _ = group +// } func (e *Engine) GroupsIter() iter.Seq[RouteGroup] { - return slices.Values(e.groups) + groups := slices.Clone(e.groups) + return slices.Values(groups) } // Register adds a route group to the engine. +// +// Example: +// +// engine.Register(myGroup) func (e *Engine) Register(group RouteGroup) { + if isNilRouteGroup(group) { + return + } e.groups = append(e.groups, group) } // Channels returns all WebSocket channel names from registered StreamGroups. // Groups that do not implement StreamGroup are silently skipped. +// +// Example: +// +// channels := engine.Channels() func (e *Engine) Channels() []string { var channels []string for _, g := range e.groups { @@ -84,9 +146,16 @@ func (e *Engine) Channels() []string { } // ChannelsIter returns an iterator over WebSocket channel names from registered StreamGroups. +// +// Example: +// +// for channel := range engine.ChannelsIter() { +// _ = channel +// } func (e *Engine) ChannelsIter() iter.Seq[string] { + groups := slices.Clone(e.groups) return func(yield func(string) bool) { - for _, g := range e.groups { + for _, g := range groups { if sg, ok := g.(StreamGroup); ok { for _, c := range sg.Channels() { if !yield(c) { @@ -100,12 +169,22 @@ func (e *Engine) ChannelsIter() iter.Seq[string] { // Handler builds the Gin engine and returns it as an http.Handler. // Each call produces a fresh handler reflecting the current set of groups. +// +// Example: +// +// handler := engine.Handler() func (e *Engine) Handler() http.Handler { return e.build() } // Serve starts the HTTP server and blocks until the context is cancelled, // then performs a graceful shutdown allowing in-flight requests to complete. +// +// Example: +// +// ctx, cancel := context.WithCancel(context.Background()) +// defer cancel() +// _ = engine.Serve(ctx) func (e *Engine) Serve(ctx context.Context) error { srv := &http.Server{ Addr: e.addr, @@ -120,8 +199,18 @@ func (e *Engine) Serve(ctx context.Context) error { close(errCh) }() - // Block until context is cancelled. - <-ctx.Done() + // Return immediately if the listener fails before shutdown is requested. + select { + case err := <-errCh: + return err + case <-ctx.Done(): + } + + // Signal SSE clients first so their handlers can exit cleanly before the + // HTTP server begins its own shutdown sequence. + if e.sseBroker != nil { + e.sseBroker.Drain() + } // Graceful shutdown with timeout. shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) @@ -139,7 +228,7 @@ func (e *Engine) Serve(ctx context.Context) error { // user-supplied middleware, the health endpoint, and all registered route groups. func (e *Engine) build() *gin.Engine { r := gin.New() - r.Use(gin.Recovery()) + r.Use(recoveryMiddleware()) // Apply user-supplied middleware after recovery but before routes. for _, mw := range e.middlewares { @@ -153,18 +242,21 @@ func (e *Engine) build() *gin.Engine { // Mount each registered group at its base path. for _, g := range e.groups { + if isNilRouteGroup(g) { + continue + } rg := r.Group(g.BasePath()) g.RegisterRoutes(rg) } // Mount WebSocket handler if configured. if e.wsHandler != nil { - r.GET("/ws", wrapWSHandler(e.wsHandler)) + r.GET(resolveWSPath(e.wsPath), wrapWSHandler(e.wsHandler)) } // Mount SSE endpoint if configured. if e.sseBroker != nil { - r.GET("/events", e.sseBroker.Handler()) + r.GET(resolveSSEPath(e.ssePath), e.sseBroker.Handler()) } // Mount GraphQL endpoint if configured. @@ -174,7 +266,7 @@ func (e *Engine) build() *gin.Engine { // Mount Swagger UI if enabled. if e.swaggerEnabled { - registerSwagger(r, e.swaggerTitle, e.swaggerDesc, e.swaggerVersion, e.groups) + registerSwagger(r, e, e.groups) } // Mount pprof profiling endpoints if enabled. @@ -189,3 +281,17 @@ func (e *Engine) build() *gin.Engine { return r } + +func isNilRouteGroup(group RouteGroup) bool { + if group == nil { + return true + } + + value := reflect.ValueOf(group) + switch value.Kind() { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice: + return value.IsNil() + default: + return false + } +} diff --git a/api_test.go b/api_test.go index f4bd8b5..948d353 100644 --- a/api_test.go +++ b/api_test.go @@ -29,6 +29,16 @@ func (h *healthGroup) RegisterRoutes(rg *gin.RouterGroup) { }) } +type panicGroup struct{} + +func (p *panicGroup) Name() string { return "panic" } +func (p *panicGroup) BasePath() string { return "/panic" } +func (p *panicGroup) RegisterRoutes(rg *gin.RouterGroup) { + rg.GET("/boom", func(c *gin.Context) { + panic("boom") + }) +} + // ── New ───────────────────────────────────────────────────────────────── func TestNew_Good(t *testing.T) { @@ -85,6 +95,28 @@ func TestRegister_Good_MultipleGroups(t *testing.T) { } } +func TestRegister_Good_GroupsReturnsCopy(t *testing.T) { + e, _ := api.New() + first := &healthGroup{} + second := &stubGroup{} + e.Register(first) + e.Register(second) + + groups := e.Groups() + groups[0] = nil + + fresh := e.Groups() + if fresh[0] == nil { + t.Fatal("expected Groups to return a copy, but engine state was mutated") + } + if fresh[0].Name() != first.Name() { + t.Fatalf("expected first group name %q, got %q", first.Name(), fresh[0].Name()) + } + if fresh[1].Name() != "stub" { + t.Fatalf("expected second group name %q, got %q", "stub", fresh[1].Name()) + } +} + // ── Handler ───────────────────────────────────────────────────────────── func TestHandler_Good_HealthEndpoint(t *testing.T) { @@ -149,6 +181,41 @@ func TestHandler_Bad_NotFound(t *testing.T) { } } +func TestHandler_Bad_PanicReturnsEnvelope(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithRequestID()) + e.Register(&panicGroup{}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/panic/boom", nil) + h.ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Fatalf("expected 500, got %d", w.Code) + } + + var resp api.Response[any] + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.Success { + t.Fatal("expected Success=false") + } + if resp.Error == nil { + t.Fatal("expected Error to be non-nil") + } + if resp.Error.Code != "internal_server_error" { + t.Fatalf("expected error code=%q, got %q", "internal_server_error", resp.Error.Code) + } + if resp.Error.Message != "Internal server error" { + t.Fatalf("expected error message=%q, got %q", "Internal server error", resp.Error.Message) + } + if got := w.Header().Get("X-Request-ID"); got == "" { + t.Fatal("expected X-Request-ID header to survive panic recovery") + } +} + // ── Serve + graceful shutdown ─────────────────────────────────────────── func TestServe_Good_GracefulShutdown(t *testing.T) { @@ -202,3 +269,32 @@ func TestServe_Good_GracefulShutdown(t *testing.T) { t.Fatal("Serve did not return within 5 seconds after context cancellation") } } + +func TestServe_Bad_ReturnsListenErrorBeforeCancel(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to reserve port: %v", err) + } + addr := ln.Addr().String() + defer ln.Close() + + e, _ := api.New(api.WithAddr(addr)) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + errCh := make(chan error, 1) + go func() { + errCh <- e.Serve(ctx) + }() + + select { + case serveErr := <-errCh: + if serveErr == nil { + t.Fatal("expected Serve to return a listen error, got nil") + } + case <-time.After(2 * time.Second): + cancel() + t.Fatal("Serve did not return promptly after listener failure") + } +} diff --git a/authentik.go b/authentik.go index fa08217..49ae1c7 100644 --- a/authentik.go +++ b/authentik.go @@ -14,6 +14,10 @@ import ( ) // AuthentikConfig holds settings for the Authentik forward-auth integration. +// +// Example: +// +// cfg := api.AuthentikConfig{Issuer: "https://auth.example.com/", ClientID: "core-api"} type AuthentikConfig struct { // Issuer is the OIDC issuer URL (e.g. https://auth.example.com/application/o/my-app/). Issuer string @@ -26,12 +30,32 @@ type AuthentikConfig struct { TrustedProxy bool // PublicPaths lists additional paths that do not require authentication. - // /health and /swagger are always public. + // /health and the configured Swagger UI path are always public. PublicPaths []string } +// AuthentikConfig returns the configured Authentik settings for the engine. +// +// The result snapshots the Engine state at call time and clones slices so +// callers can safely reuse or modify the returned value. +// +// Example: +// +// cfg := engine.AuthentikConfig() +func (e *Engine) AuthentikConfig() AuthentikConfig { + if e == nil { + return AuthentikConfig{} + } + + return cloneAuthentikConfig(e.authentikConfig) +} + // AuthentikUser represents an authenticated user extracted from Authentik // forward-auth headers or a validated JWT. +// +// Example: +// +// user := &api.AuthentikUser{Username: "alice", Groups: []string{"admins"}} type AuthentikUser struct { Username string `json:"username"` Email string `json:"email"` @@ -43,6 +67,10 @@ type AuthentikUser struct { } // HasGroup reports whether the user belongs to the named group. +// +// Example: +// +// user.HasGroup("admins") func (u *AuthentikUser) HasGroup(group string) bool { return slices.Contains(u.Groups, group) } @@ -53,6 +81,10 @@ const authentikUserKey = "authentik_user" // GetUser retrieves the AuthentikUser from the Gin context. // Returns nil when no user has been set (unauthenticated request or // middleware not active). +// +// Example: +// +// user := api.GetUser(c) func GetUser(c *gin.Context) *AuthentikUser { val, exists := c.Get(authentikUserKey) if !exists { @@ -134,7 +166,7 @@ func validateJWT(ctx context.Context, cfg AuthentikConfig, rawToken string) (*Au // The middleware is PERMISSIVE: it populates the context when credentials are // present but never rejects unauthenticated requests. Downstream handlers // use GetUser to check authentication. -func authentikMiddleware(cfg AuthentikConfig) gin.HandlerFunc { +func authentikMiddleware(cfg AuthentikConfig, publicPaths func() []string) gin.HandlerFunc { // Build the set of public paths that skip header extraction entirely. public := map[string]bool{ "/health": true, @@ -148,11 +180,19 @@ func authentikMiddleware(cfg AuthentikConfig) gin.HandlerFunc { // Skip public paths. path := c.Request.URL.Path for p := range public { - if strings.HasPrefix(path, p) { + if isPublicPath(path, p) { c.Next() return } } + if publicPaths != nil { + for _, p := range publicPaths() { + if isPublicPath(path, p) { + c.Next() + return + } + } + } // Block 1: Extract user from X-authentik-* forward-auth headers. if cfg.TrustedProxy { @@ -193,9 +233,57 @@ func authentikMiddleware(cfg AuthentikConfig) gin.HandlerFunc { } } +func cloneAuthentikConfig(cfg AuthentikConfig) AuthentikConfig { + out := cfg + out.Issuer = strings.TrimSpace(out.Issuer) + out.ClientID = strings.TrimSpace(out.ClientID) + out.PublicPaths = normalisePublicPaths(cfg.PublicPaths) + return out +} + +// normalisePublicPaths trims whitespace, ensures a leading slash, and removes +// duplicate entries while preserving the first occurrence of each path. +func normalisePublicPaths(paths []string) []string { + if len(paths) == 0 { + return nil + } + + out := make([]string, 0, len(paths)) + seen := make(map[string]struct{}, len(paths)) + + for _, path := range paths { + path = strings.TrimSpace(path) + if path == "" { + continue + } + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + path = strings.TrimRight(path, "/") + if path == "" { + path = "/" + } + if _, ok := seen[path]; ok { + continue + } + seen[path] = struct{}{} + out = append(out, path) + } + + if len(out) == 0 { + return nil + } + + return out +} + // RequireAuth is Gin middleware that rejects unauthenticated requests. // It checks for a user set by the Authentik middleware and returns 401 // when none is present. +// +// Example: +// +// r.GET("/private", api.RequireAuth(), handler) func RequireAuth() gin.HandlerFunc { return func(c *gin.Context) { if GetUser(c) == nil { @@ -210,6 +298,10 @@ func RequireAuth() gin.HandlerFunc { // RequireGroup is Gin middleware that rejects requests from users who do // not belong to the specified group. Returns 401 when no user is present // and 403 when the user lacks the required group membership. +// +// Example: +// +// r.GET("/admin", api.RequireGroup("admins"), handler) func RequireGroup(group string) gin.HandlerFunc { return func(c *gin.Context) { user := GetUser(c) diff --git a/authentik_test.go b/authentik_test.go index ab6c4d8..b44b7c8 100644 --- a/authentik_test.go +++ b/authentik_test.go @@ -221,6 +221,27 @@ func TestHealthBypassesAuthentik_Good(t *testing.T) { } } +func TestPublicPaths_Good_SimilarPrefixDoesNotBypassAuth(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := api.AuthentikConfig{ + TrustedProxy: true, + PublicPaths: []string{"/public"}, + } + e, _ := api.New(api.WithAuthentik(cfg)) + e.Register(&publicPrefixGroup{}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/publicity/secure", nil) + req.Header.Set("X-authentik-username", "alice") + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200 for /publicity/secure with auth header, got %d: %s", w.Code, w.Body.String()) + } +} + func TestGetUser_Good_NilContext(t *testing.T) { gin.SetMode(gin.TestMode) @@ -322,6 +343,33 @@ func TestBearerAndAuthentikCoexist_Good(t *testing.T) { } } +func TestAuthentik_Good_CustomSwaggerPathBypassesAuth(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := api.AuthentikConfig{TrustedProxy: true} + e, err := api.New( + api.WithAuthentik(cfg), + api.WithSwagger("Test API", "A test API service", "1.0.0"), + api.WithSwaggerPath("/docs"), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + srv := httptest.NewServer(e.Handler()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/docs/doc.json") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200 for custom swagger path without auth, got %d", resp.StatusCode) + } +} + // ── RequireAuth / RequireGroup ──────────────────────────────────────── func TestRequireAuth_Good(t *testing.T) { @@ -458,3 +506,15 @@ func (g *groupRequireGroup) RegisterRoutes(rg *gin.RouterGroup) { c.JSON(200, api.OK("admin panel")) }) } + +// publicPrefixGroup provides a route that should still be processed by auth +// middleware even though its path shares a prefix with a public path. +type publicPrefixGroup struct{} + +func (g *publicPrefixGroup) Name() string { return "public-prefix" } +func (g *publicPrefixGroup) BasePath() string { return "/publicity" } +func (g *publicPrefixGroup) RegisterRoutes(rg *gin.RouterGroup) { + rg.GET("/secure", api.RequireAuth(), func(c *gin.Context) { + c.JSON(http.StatusOK, api.OK("protected")) + }) +} diff --git a/bridge.go b/bridge.go index 79e2e78..101d199 100644 --- a/bridge.go +++ b/bridge.go @@ -3,12 +3,30 @@ package api import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "io" "iter" + "net" + "net/http" + "reflect" + "regexp" + "slices" + "strconv" + "unicode/utf8" "github.com/gin-gonic/gin" + + coreerr "dappco.re/go/core/log" ) // ToolDescriptor describes a tool that can be exposed as a REST endpoint. +// +// Example: +// +// desc := api.ToolDescriptor{Name: "ping", Description: "Ping the service"} type ToolDescriptor struct { Name string // Tool name, e.g. "file_read" (becomes POST path segment) Description string // Human-readable description @@ -19,6 +37,10 @@ type ToolDescriptor struct { // ToolBridge converts tool descriptors into REST endpoints and OpenAPI paths. // It implements both RouteGroup and DescribableGroup. +// +// Example: +// +// bridge := api.NewToolBridge("/mcp") type ToolBridge struct { basePath string name string @@ -30,7 +52,14 @@ type boundTool struct { handler gin.HandlerFunc } +var _ RouteGroup = (*ToolBridge)(nil) +var _ DescribableGroup = (*ToolBridge)(nil) + // NewToolBridge creates a bridge that mounts tool endpoints at basePath. +// +// Example: +// +// bridge := api.NewToolBridge("/mcp") func NewToolBridge(basePath string) *ToolBridge { return &ToolBridge{ basePath: basePath, @@ -39,17 +68,39 @@ func NewToolBridge(basePath string) *ToolBridge { } // Add registers a tool with its HTTP handler. +// +// Example: +// +// bridge.Add(api.ToolDescriptor{Name: "ping", Description: "Ping the service"}, handler) func (b *ToolBridge) Add(desc ToolDescriptor, handler gin.HandlerFunc) { + if validator := newToolInputValidator(desc.OutputSchema); validator != nil { + handler = wrapToolResponseHandler(handler, validator) + } + if validator := newToolInputValidator(desc.InputSchema); validator != nil { + handler = wrapToolHandler(handler, validator) + } b.tools = append(b.tools, boundTool{descriptor: desc, handler: handler}) } // Name returns the bridge identifier. +// +// Example: +// +// name := bridge.Name() func (b *ToolBridge) Name() string { return b.name } // BasePath returns the URL prefix for all tool endpoints. +// +// Example: +// +// path := bridge.BasePath() func (b *ToolBridge) BasePath() string { return b.basePath } // RegisterRoutes mounts POST /{tool_name} for each registered tool. +// +// Example: +// +// bridge.RegisterRoutes(rg) func (b *ToolBridge) RegisterRoutes(rg *gin.RouterGroup) { for _, t := range b.tools { rg.POST("/"+t.descriptor.Name, t.handler) @@ -57,44 +108,31 @@ func (b *ToolBridge) RegisterRoutes(rg *gin.RouterGroup) { } // Describe returns OpenAPI route descriptions for all registered tools. +// +// Example: +// +// descs := bridge.Describe() func (b *ToolBridge) Describe() []RouteDescription { - descs := make([]RouteDescription, 0, len(b.tools)) - for _, t := range b.tools { - tags := []string{t.descriptor.Group} - if t.descriptor.Group == "" { - tags = []string{b.name} - } - descs = append(descs, RouteDescription{ - Method: "POST", - Path: "/" + t.descriptor.Name, - Summary: t.descriptor.Description, - Description: t.descriptor.Description, - Tags: tags, - RequestBody: t.descriptor.InputSchema, - Response: t.descriptor.OutputSchema, - }) + tools := b.snapshotTools() + descs := make([]RouteDescription, 0, len(tools)) + for _, tool := range tools { + descs = append(descs, describeTool(tool.descriptor, b.name)) } return descs } // DescribeIter returns an iterator over OpenAPI route descriptions for all registered tools. +// +// Example: +// +// for rd := range bridge.DescribeIter() { +// _ = rd +// } func (b *ToolBridge) DescribeIter() iter.Seq[RouteDescription] { + tools := b.snapshotTools() return func(yield func(RouteDescription) bool) { - for _, t := range b.tools { - tags := []string{t.descriptor.Group} - if t.descriptor.Group == "" { - tags = []string{b.name} - } - rd := RouteDescription{ - Method: "POST", - Path: "/" + t.descriptor.Name, - Summary: t.descriptor.Description, - Description: t.descriptor.Description, - Tags: tags, - RequestBody: t.descriptor.InputSchema, - Response: t.descriptor.OutputSchema, - } - if !yield(rd) { + for _, tool := range tools { + if !yield(describeTool(tool.descriptor, b.name)) { return } } @@ -102,21 +140,778 @@ func (b *ToolBridge) DescribeIter() iter.Seq[RouteDescription] { } // Tools returns all registered tool descriptors. +// +// Example: +// +// descs := bridge.Tools() func (b *ToolBridge) Tools() []ToolDescriptor { - descs := make([]ToolDescriptor, len(b.tools)) - for i, t := range b.tools { + tools := b.snapshotTools() + descs := make([]ToolDescriptor, len(tools)) + for i, t := range tools { descs[i] = t.descriptor } return descs } // ToolsIter returns an iterator over all registered tool descriptors. +// +// Example: +// +// for desc := range bridge.ToolsIter() { +// _ = desc +// } func (b *ToolBridge) ToolsIter() iter.Seq[ToolDescriptor] { + tools := b.snapshotTools() return func(yield func(ToolDescriptor) bool) { - for _, t := range b.tools { - if !yield(t.descriptor) { + for _, tool := range tools { + if !yield(tool.descriptor) { + return + } + } + } +} + +func (b *ToolBridge) snapshotTools() []boundTool { + if len(b.tools) == 0 { + return nil + } + return slices.Clone(b.tools) +} + +func describeTool(desc ToolDescriptor, defaultTag string) RouteDescription { + tags := cleanTags([]string{desc.Group}) + if len(tags) == 0 { + tags = []string{defaultTag} + } + return RouteDescription{ + Method: "POST", + Path: "/" + desc.Name, + Summary: desc.Description, + Description: desc.Description, + Tags: tags, + RequestBody: desc.InputSchema, + Response: desc.OutputSchema, + } +} + +// maxToolRequestBodyBytes is the maximum request body size accepted by the +// tool bridge handler. Requests larger than this are rejected with 413. +const maxToolRequestBodyBytes = 10 << 20 // 10 MiB + +func wrapToolHandler(handler gin.HandlerFunc, validator *toolInputValidator) gin.HandlerFunc { + return func(c *gin.Context) { + limited := http.MaxBytesReader(c.Writer, c.Request.Body, maxToolRequestBodyBytes) + body, err := io.ReadAll(limited) + if err != nil { + status := http.StatusBadRequest + msg := "Unable to read request body" + if err.Error() == "http: request body too large" { + status = http.StatusRequestEntityTooLarge + msg = "Request body exceeds the maximum allowed size" + } + c.AbortWithStatusJSON(status, FailWithDetails( + "invalid_request_body", + msg, + map[string]any{"error": err.Error()}, + )) + return + } + + if err := validator.Validate(body); err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, FailWithDetails( + "invalid_request_body", + "Request body does not match the declared tool schema", + map[string]any{"error": err.Error()}, + )) + return + } + + c.Request.Body = io.NopCloser(bytes.NewReader(body)) + handler(c) + } +} + +func wrapToolResponseHandler(handler gin.HandlerFunc, validator *toolInputValidator) gin.HandlerFunc { + return func(c *gin.Context) { + recorder := newToolResponseRecorder(c.Writer) + c.Writer = recorder + + handler(c) + + if recorder.Status() >= 200 && recorder.Status() < 300 { + if err := validator.ValidateResponse(recorder.body.Bytes()); err != nil { + recorder.reset() + recorder.writeErrorResponse(http.StatusInternalServerError, FailWithDetails( + "invalid_tool_response", + "Tool response does not match the declared output schema", + map[string]any{"error": err.Error()}, + )) return } } + + recorder.commit() + } +} + +type toolInputValidator struct { + schema map[string]any +} + +func newToolInputValidator(schema map[string]any) *toolInputValidator { + if len(schema) == 0 { + return nil + } + return &toolInputValidator{schema: schema} +} + +func (v *toolInputValidator) Validate(body []byte) error { + if len(bytes.TrimSpace(body)) == 0 { + return coreerr.E("ToolBridge.Validate", "request body is required", nil) + } + + dec := json.NewDecoder(bytes.NewReader(body)) + dec.UseNumber() + + var payload any + if err := dec.Decode(&payload); err != nil { + return coreerr.E("ToolBridge.Validate", "invalid JSON", err) + } + var extra any + if err := dec.Decode(&extra); err != io.EOF { + return coreerr.E("ToolBridge.Validate", "request body must contain a single JSON value", nil) + } + + return validateSchemaNode(payload, v.schema, "") +} + +func (v *toolInputValidator) ValidateResponse(body []byte) error { + if len(v.schema) == 0 { + return nil + } + + // Use a decoder with UseNumber so that large integers in the envelope + // (including within the data field) are preserved as json.Number rather + // than being silently coerced to float64. This matches the behaviour of + // the Validate path and avoids precision loss for 64-bit integer values. + var envelope map[string]any + envDec := json.NewDecoder(bytes.NewReader(body)) + envDec.UseNumber() + if err := envDec.Decode(&envelope); err != nil { + return coreerr.E("ToolBridge.ValidateResponse", "invalid JSON response", err) + } + + success, _ := envelope["success"].(bool) + if !success { + return coreerr.E("ToolBridge.ValidateResponse", "response is missing a successful envelope", nil) + } + + // data is serialised with omitempty, so a nil/zero-value payload from + // constructors like OK(nil) or OK(false) will omit the key entirely. + // Treat a missing data key as a valid nil payload for successful responses. + data, ok := envelope["data"] + if !ok { + return nil + } + + encoded, err := json.Marshal(data) + if err != nil { + return coreerr.E("ToolBridge.ValidateResponse", "encode response data", err) + } + + var payload any + dec := json.NewDecoder(bytes.NewReader(encoded)) + dec.UseNumber() + if err := dec.Decode(&payload); err != nil { + return coreerr.E("ToolBridge.ValidateResponse", "decode response data", err) + } + + return validateSchemaNode(payload, v.schema, "") +} + +func validateSchemaNode(value any, schema map[string]any, path string) error { + if len(schema) == 0 { + return nil + } + + schemaType, _ := schema["type"].(string) + if schemaType != "" { + switch schemaType { + case "object": + obj, ok := value.(map[string]any) + if !ok { + return typeError(path, "object", value) + } + + for _, name := range stringList(schema["required"]) { + if _, ok := obj[name]; !ok { + return coreerr.E("ToolBridge.ValidateSchema", fmt.Sprintf("%s is missing required field %q", displayPath(path), name), nil) + } + } + + for name, rawChild := range schemaMap(schema["properties"]) { + childSchema, ok := rawChild.(map[string]any) + if !ok { + continue + } + childValue, ok := obj[name] + if !ok { + continue + } + if err := validateSchemaNode(childValue, childSchema, joinPath(path, name)); err != nil { + return err + } + } + + if additionalProperties, ok := schema["additionalProperties"].(bool); ok && !additionalProperties { + properties := schemaMap(schema["properties"]) + for name := range obj { + if properties != nil { + if _, ok := properties[name]; ok { + continue + } + } + return coreerr.E("ToolBridge.ValidateSchema", fmt.Sprintf("%s contains unknown field %q", displayPath(path), name), nil) + } + } + if err := validateObjectConstraints(obj, schema, path); err != nil { + return err + } + case "array": + arr, ok := value.([]any) + if !ok { + return typeError(path, "array", value) + } + if items := schemaMap(schema["items"]); len(items) > 0 { + for i, item := range arr { + if err := validateSchemaNode(item, items, joinPath(path, strconv.Itoa(i))); err != nil { + return err + } + } + } + if err := validateArrayConstraints(arr, schema, path); err != nil { + return err + } + case "string": + str, ok := value.(string) + if !ok { + return typeError(path, "string", value) + } + if err := validateStringConstraints(str, schema, path); err != nil { + return err + } + case "boolean": + if _, ok := value.(bool); !ok { + return typeError(path, "boolean", value) + } + case "integer": + if !isIntegerValue(value) { + return typeError(path, "integer", value) + } + if err := validateNumericConstraints(value, schema, path); err != nil { + return err + } + case "number": + if !isNumberValue(value) { + return typeError(path, "number", value) + } + if err := validateNumericConstraints(value, schema, path); err != nil { + return err + } + } + } + + if schemaType == "" && (len(schemaMap(schema["properties"])) > 0 || schema["required"] != nil || schema["additionalProperties"] != nil) { + props := schemaMap(schema["properties"]) + return validateSchemaNode(value, map[string]any{ + "type": "object", + "properties": props, + "required": schema["required"], + "additionalProperties": schema["additionalProperties"], + }, path) + } + + if rawEnum, ok := schema["enum"]; ok { + if !enumContains(value, rawEnum) { + return coreerr.E("ToolBridge.ValidateSchema", fmt.Sprintf("%s must be one of the declared enum values", displayPath(path)), nil) + } + } + + if err := validateSchemaCombinators(value, schema, path); err != nil { + return err + } + + return nil +} + +func validateSchemaCombinators(value any, schema map[string]any, path string) error { + if subschemas := schemaObjects(schema["allOf"]); len(subschemas) > 0 { + for _, subschema := range subschemas { + if err := validateSchemaNode(value, subschema, path); err != nil { + return err + } + } + } + + if subschemas := schemaObjects(schema["anyOf"]); len(subschemas) > 0 { + for _, subschema := range subschemas { + if err := validateSchemaNode(value, subschema, path); err == nil { + goto anyOfMatched + } + } + return coreerr.E("ToolBridge.ValidateSchema", fmt.Sprintf("%s must match at least one schema in anyOf", displayPath(path)), nil) + } + +anyOfMatched: + if subschemas := schemaObjects(schema["oneOf"]); len(subschemas) > 0 { + matches := 0 + for _, subschema := range subschemas { + if err := validateSchemaNode(value, subschema, path); err == nil { + matches++ + } + } + if matches != 1 { + if matches == 0 { + return coreerr.E("ToolBridge.ValidateSchema", fmt.Sprintf("%s must match exactly one schema in oneOf", displayPath(path)), nil) + } + return coreerr.E("ToolBridge.ValidateSchema", fmt.Sprintf("%s matches multiple schemas in oneOf", displayPath(path)), nil) + } + } + + if subschema, ok := schema["not"].(map[string]any); ok && subschema != nil { + if err := validateSchemaNode(value, subschema, path); err == nil { + return coreerr.E("ToolBridge.ValidateSchema", fmt.Sprintf("%s must not match the forbidden schema", displayPath(path)), nil) + } + } + + return nil +} + +func validateStringConstraints(value string, schema map[string]any, path string) error { + length := utf8.RuneCountInString(value) + if minLength, ok := schemaInt(schema["minLength"]); ok && length < minLength { + return coreerr.E("ToolBridge.ValidateSchema", fmt.Sprintf("%s must be at least %d characters long", displayPath(path), minLength), nil) + } + if maxLength, ok := schemaInt(schema["maxLength"]); ok && length > maxLength { + return coreerr.E("ToolBridge.ValidateSchema", fmt.Sprintf("%s must be at most %d characters long", displayPath(path), maxLength), nil) + } + if pattern, ok := schema["pattern"].(string); ok && pattern != "" { + re, err := regexp.Compile(pattern) + if err != nil { + return coreerr.E("ToolBridge.ValidateSchema", fmt.Sprintf("%s has an invalid pattern %q", displayPath(path), pattern), err) + } + if !re.MatchString(value) { + return coreerr.E("ToolBridge.ValidateSchema", fmt.Sprintf("%s does not match pattern %q", displayPath(path), pattern), nil) + } + } + return nil +} + +func validateNumericConstraints(value any, schema map[string]any, path string) error { + if minimum, ok := schemaFloat(schema["minimum"]); ok && numericLessThan(value, minimum) { + return coreerr.E("ToolBridge.ValidateSchema", fmt.Sprintf("%s must be greater than or equal to %v", displayPath(path), minimum), nil) + } + if maximum, ok := schemaFloat(schema["maximum"]); ok && numericGreaterThan(value, maximum) { + return coreerr.E("ToolBridge.ValidateSchema", fmt.Sprintf("%s must be less than or equal to %v", displayPath(path), maximum), nil) + } + return nil +} + +func validateArrayConstraints(value []any, schema map[string]any, path string) error { + if minItems, ok := schemaInt(schema["minItems"]); ok && len(value) < minItems { + return coreerr.E("ToolBridge.ValidateSchema", fmt.Sprintf("%s must contain at least %d items", displayPath(path), minItems), nil) + } + if maxItems, ok := schemaInt(schema["maxItems"]); ok && len(value) > maxItems { + return coreerr.E("ToolBridge.ValidateSchema", fmt.Sprintf("%s must contain at most %d items", displayPath(path), maxItems), nil) + } + return nil +} + +func validateObjectConstraints(value map[string]any, schema map[string]any, path string) error { + if minProps, ok := schemaInt(schema["minProperties"]); ok && len(value) < minProps { + return coreerr.E("ToolBridge.ValidateSchema", fmt.Sprintf("%s must contain at least %d properties", displayPath(path), minProps), nil) + } + if maxProps, ok := schemaInt(schema["maxProperties"]); ok && len(value) > maxProps { + return coreerr.E("ToolBridge.ValidateSchema", fmt.Sprintf("%s must contain at most %d properties", displayPath(path), maxProps), nil) + } + return nil +} + +func schemaInt(value any) (int, bool) { + switch v := value.(type) { + case int: + return v, true + case int8: + return int(v), true + case int16: + return int(v), true + case int32: + return int(v), true + case int64: + return int(v), true + case uint: + return int(v), true + case uint8: + return int(v), true + case uint16: + return int(v), true + case uint32: + return int(v), true + case uint64: + return int(v), true + case float64: + if v == float64(int(v)) { + return int(v), true + } + case json.Number: + if n, err := v.Int64(); err == nil { + return int(n), true + } + } + return 0, false +} + +func schemaFloat(value any) (float64, bool) { + switch v := value.(type) { + case float64: + return v, true + case float32: + return float64(v), true + case int: + return float64(v), true + case int8: + return float64(v), true + case int16: + return float64(v), true + case int32: + return float64(v), true + case int64: + return float64(v), true + case uint: + return float64(v), true + case uint8: + return float64(v), true + case uint16: + return float64(v), true + case uint32: + return float64(v), true + case uint64: + return float64(v), true + case json.Number: + if n, err := v.Float64(); err == nil { + return n, true + } + } + return 0, false +} + +func numericLessThan(value any, limit float64) bool { + if n, ok := numericValue(value); ok { + return n < limit + } + return false +} + +func numericGreaterThan(value any, limit float64) bool { + if n, ok := numericValue(value); ok { + return n > limit + } + return false +} + +type toolResponseRecorder struct { + gin.ResponseWriter + headers http.Header + body bytes.Buffer + status int + wroteHeader bool +} + +func newToolResponseRecorder(w gin.ResponseWriter) *toolResponseRecorder { + headers := make(http.Header) + for k, vals := range w.Header() { + headers[k] = append([]string(nil), vals...) + } + return &toolResponseRecorder{ + ResponseWriter: w, + headers: headers, + status: http.StatusOK, + } +} + +func (w *toolResponseRecorder) Header() http.Header { + return w.headers +} + +func (w *toolResponseRecorder) WriteHeader(code int) { + w.status = code + w.wroteHeader = true +} + +func (w *toolResponseRecorder) WriteHeaderNow() { + w.wroteHeader = true +} + +func (w *toolResponseRecorder) Write(data []byte) (int, error) { + if !w.wroteHeader { + w.WriteHeader(http.StatusOK) + } + return w.body.Write(data) +} + +func (w *toolResponseRecorder) WriteString(s string) (int, error) { + if !w.wroteHeader { + w.WriteHeader(http.StatusOK) + } + return w.body.WriteString(s) +} + +func (w *toolResponseRecorder) Flush() { +} + +func (w *toolResponseRecorder) Status() int { + if w.wroteHeader { + return w.status + } + return http.StatusOK +} + +func (w *toolResponseRecorder) Size() int { + return w.body.Len() +} + +func (w *toolResponseRecorder) Written() bool { + return w.wroteHeader +} + +func (w *toolResponseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, coreerr.E("ToolBridge.ResponseRecorder", "response hijacking is not supported by ToolBridge output validation", nil) +} + +func (w *toolResponseRecorder) commit() { + for k := range w.ResponseWriter.Header() { + w.ResponseWriter.Header().Del(k) + } + for k, vals := range w.headers { + for _, v := range vals { + w.ResponseWriter.Header().Add(k, v) + } + } + w.ResponseWriter.WriteHeader(w.Status()) + _, _ = w.ResponseWriter.Write(w.body.Bytes()) +} + +func (w *toolResponseRecorder) reset() { + w.headers = make(http.Header) + w.body.Reset() + w.status = http.StatusInternalServerError + w.wroteHeader = false +} + +func (w *toolResponseRecorder) writeErrorResponse(status int, resp Response[any]) { + data, err := json.Marshal(resp) + if err != nil { + w.status = http.StatusInternalServerError + w.wroteHeader = true + http.Error(w.ResponseWriter, "internal server error", http.StatusInternalServerError) + return + } + + // Keep recorder state aligned with the replacement response so that + // Status(), Written(), Header() and Size() all reflect the error + // response. Post-handler middleware and metrics must observe correct + // values, not stale state from the reset() call above. + w.status = status + w.wroteHeader = true + if w.headers == nil { + w.headers = make(http.Header) + } + w.headers.Set("Content-Type", "application/json") + w.body.Reset() + _, _ = w.body.Write(data) + w.commit() +} + +func typeError(path, want string, value any) error { + return coreerr.E("ToolBridge.ValidateSchema", fmt.Sprintf("%s must be %s, got %s", displayPath(path), want, describeJSONValue(value)), nil) +} + +func displayPath(path string) string { + if path == "" { + return "request body" + } + return "request body." + path +} + +func joinPath(parent, child string) string { + if parent == "" { + return child + } + return parent + "." + child +} + +func schemaMap(value any) map[string]any { + if value == nil { + return nil + } + m, _ := value.(map[string]any) + return m +} + +func schemaObjects(value any) []map[string]any { + switch raw := value.(type) { + case []any: + out := make([]map[string]any, 0, len(raw)) + for _, item := range raw { + if schema := schemaMap(item); schema != nil { + out = append(out, schema) + } + } + return out + case []map[string]any: + return append([]map[string]any(nil), raw...) + default: + return nil + } +} + +func stringList(value any) []string { + switch raw := value.(type) { + case []any: + out := make([]string, 0, len(raw)) + for _, item := range raw { + name, ok := item.(string) + if !ok { + continue + } + out = append(out, name) + } + return out + case []string: + return append([]string(nil), raw...) + default: + return nil + } +} + +func isIntegerValue(value any) bool { + switch v := value.(type) { + case json.Number: + _, err := v.Int64() + return err == nil + case float64: + return v == float64(int64(v)) + default: + return false + } +} + +func isNumberValue(value any) bool { + switch value.(type) { + case json.Number, float64: + return true + default: + return false + } +} + +func enumContains(value any, rawEnum any) bool { + items := enumValues(rawEnum) + for _, candidate := range items { + if valuesEqual(value, candidate) { + return true + } + } + return false +} + +func enumValues(rawEnum any) []any { + switch values := rawEnum.(type) { + case []any: + out := make([]any, 0, len(values)) + for _, value := range values { + out = append(out, value) + } + return out + case []string: + out := make([]any, 0, len(values)) + for _, value := range values { + out = append(out, value) + } + return out + default: + return nil + } +} + +func valuesEqual(left, right any) bool { + if isNumericValue(left) && isNumericValue(right) { + lv, lok := numericValue(left) + rv, rok := numericValue(right) + return lok && rok && lv == rv + } + return reflect.DeepEqual(left, right) +} + +func isNumericValue(value any) bool { + switch value.(type) { + case json.Number, float64, float32, int, int8, int16, int32, int64, + uint, uint8, uint16, uint32, uint64: + return true + default: + return false + } +} + +func numericValue(value any) (float64, bool) { + switch v := value.(type) { + case json.Number: + n, err := v.Float64() + return n, err == nil + case float64: + return v, true + case float32: + return float64(v), true + case int: + return float64(v), true + case int8: + return float64(v), true + case int16: + return float64(v), true + case int32: + return float64(v), true + case int64: + return float64(v), true + case uint: + return float64(v), true + case uint8: + return float64(v), true + case uint16: + return float64(v), true + case uint32: + return float64(v), true + case uint64: + return float64(v), true + default: + return 0, false + } +} + +func describeJSONValue(value any) string { + switch value.(type) { + case nil: + return "null" + case string: + return "string" + case bool: + return "boolean" + case json.Number, float64: + return "number" + case map[string]any: + return "object" + case []any: + return "array" + default: + return fmt.Sprintf("%T", value) } } diff --git a/bridge_test.go b/bridge_test.go index 3c5c6c4..3b26b1f 100644 --- a/bridge_test.go +++ b/bridge_test.go @@ -3,6 +3,7 @@ package api_test import ( + "bytes" "encoding/json" "net/http" "net/http/httptest" @@ -153,6 +154,522 @@ func TestToolBridge_Good_Describe(t *testing.T) { } } +func TestToolBridge_Good_DescribeTrimsBlankGroup(t *testing.T) { + bridge := api.NewToolBridge("/tools") + bridge.Add(api.ToolDescriptor{ + Name: "file_read", + Description: "Read a file from disk", + Group: " ", + }, func(c *gin.Context) {}) + + descs := bridge.Describe() + if len(descs) != 1 { + t.Fatalf("expected 1 description, got %d", len(descs)) + } + if len(descs[0].Tags) != 1 || descs[0].Tags[0] != "tools" { + t.Fatalf("expected blank group to fall back to bridge tag, got %v", descs[0].Tags) + } +} + +func TestToolBridge_Good_ValidatesRequestBody(t *testing.T) { + gin.SetMode(gin.TestMode) + engine := gin.New() + + bridge := api.NewToolBridge("/tools") + bridge.Add(api.ToolDescriptor{ + Name: "file_read", + Description: "Read a file from disk", + Group: "files", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{"type": "string"}, + }, + "required": []any{"path"}, + }, + }, func(c *gin.Context) { + var payload map[string]any + if err := json.NewDecoder(c.Request.Body).Decode(&payload); err != nil { + t.Fatalf("handler could not read validated body: %v", err) + } + c.JSON(http.StatusOK, api.OK(payload["path"])) + }) + + rg := engine.Group(bridge.BasePath()) + bridge.RegisterRoutes(rg) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPost, "/tools/file_read", bytes.NewBufferString(`{"path":"/tmp/file.txt"}`)) + engine.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + + var resp api.Response[string] + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.Data != "/tmp/file.txt" { + t.Fatalf("expected validated payload to reach handler, got %q", resp.Data) + } +} + +func TestToolBridge_Good_ValidatesResponseBody(t *testing.T) { + gin.SetMode(gin.TestMode) + engine := gin.New() + + bridge := api.NewToolBridge("/tools") + bridge.Add(api.ToolDescriptor{ + Name: "file_read", + Description: "Read a file from disk", + Group: "files", + OutputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{"type": "string"}, + }, + "required": []any{"path"}, + }, + }, func(c *gin.Context) { + c.JSON(http.StatusOK, api.OK(map[string]any{"path": "/tmp/file.txt"})) + }) + + rg := engine.Group(bridge.BasePath()) + bridge.RegisterRoutes(rg) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPost, "/tools/file_read", nil) + engine.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + + var resp api.Response[map[string]any] + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if !resp.Success { + t.Fatal("expected Success=true") + } + if resp.Data["path"] != "/tmp/file.txt" { + t.Fatalf("expected validated response data to reach client, got %v", resp.Data["path"]) + } +} + +func TestToolBridge_Bad_InvalidResponseBody(t *testing.T) { + gin.SetMode(gin.TestMode) + engine := gin.New() + + bridge := api.NewToolBridge("/tools") + bridge.Add(api.ToolDescriptor{ + Name: "file_read", + Description: "Read a file from disk", + Group: "files", + OutputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{"type": "string"}, + }, + "required": []any{"path"}, + }, + }, func(c *gin.Context) { + c.JSON(http.StatusOK, api.OK(map[string]any{"path": 123})) + }) + + rg := engine.Group(bridge.BasePath()) + bridge.RegisterRoutes(rg) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPost, "/tools/file_read", nil) + engine.ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Fatalf("expected 500, got %d", w.Code) + } + + var resp api.Response[any] + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.Success { + t.Fatal("expected Success=false") + } + if resp.Error == nil || resp.Error.Code != "invalid_tool_response" { + t.Fatalf("expected invalid_tool_response error, got %#v", resp.Error) + } +} + +func TestToolBridge_Bad_InvalidRequestBody(t *testing.T) { + gin.SetMode(gin.TestMode) + engine := gin.New() + + bridge := api.NewToolBridge("/tools") + bridge.Add(api.ToolDescriptor{ + Name: "file_read", + Description: "Read a file from disk", + Group: "files", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{"type": "string"}, + }, + "required": []any{"path"}, + }, + }, func(c *gin.Context) { + c.JSON(http.StatusOK, api.OK("should not run")) + }) + + rg := engine.Group(bridge.BasePath()) + bridge.RegisterRoutes(rg) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPost, "/tools/file_read", bytes.NewBufferString(`{"path":123}`)) + engine.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", w.Code) + } + + var resp api.Response[any] + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.Success { + t.Fatal("expected Success=false") + } + if resp.Error == nil || resp.Error.Code != "invalid_request_body" { + t.Fatalf("expected invalid_request_body error, got %#v", resp.Error) + } +} + +func TestToolBridge_Good_ValidatesEnumValues(t *testing.T) { + gin.SetMode(gin.TestMode) + engine := gin.New() + + bridge := api.NewToolBridge("/tools") + bridge.Add(api.ToolDescriptor{ + Name: "publish_item", + Description: "Publish an item", + Group: "items", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "status": map[string]any{ + "type": "string", + "enum": []any{"draft", "published"}, + }, + }, + "required": []any{"status"}, + }, + }, func(c *gin.Context) { + c.JSON(http.StatusOK, api.OK("published")) + }) + + rg := engine.Group(bridge.BasePath()) + bridge.RegisterRoutes(rg) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPost, "/tools/publish_item", bytes.NewBufferString(`{"status":"published"}`)) + engine.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } +} + +func TestToolBridge_Bad_RejectsInvalidEnumValues(t *testing.T) { + gin.SetMode(gin.TestMode) + engine := gin.New() + + bridge := api.NewToolBridge("/tools") + bridge.Add(api.ToolDescriptor{ + Name: "publish_item", + Description: "Publish an item", + Group: "items", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "status": map[string]any{ + "type": "string", + "enum": []any{"draft", "published"}, + }, + }, + "required": []any{"status"}, + }, + }, func(c *gin.Context) { + c.JSON(http.StatusOK, api.OK("published")) + }) + + rg := engine.Group(bridge.BasePath()) + bridge.RegisterRoutes(rg) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPost, "/tools/publish_item", bytes.NewBufferString(`{"status":"archived"}`)) + engine.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", w.Code) + } + + var resp api.Response[any] + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.Success { + t.Fatal("expected Success=false") + } + if resp.Error == nil || resp.Error.Code != "invalid_request_body" { + t.Fatalf("expected invalid_request_body error, got %#v", resp.Error) + } +} + +func TestToolBridge_Good_ValidatesSchemaCombinators(t *testing.T) { + gin.SetMode(gin.TestMode) + engine := gin.New() + + bridge := api.NewToolBridge("/tools") + bridge.Add(api.ToolDescriptor{ + Name: "route_choice", + Description: "Choose a route", + Group: "items", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "choice": map[string]any{ + "oneOf": []any{ + map[string]any{ + "type": "string", + "allOf": []any{ + map[string]any{"minLength": 2}, + map[string]any{"pattern": "^[A-Z]+$"}, + }, + }, + map[string]any{ + "type": "string", + "pattern": "^A", + }, + }, + }, + }, + "required": []any{"choice"}, + }, + }, func(c *gin.Context) { + c.JSON(http.StatusOK, api.OK("accepted")) + }) + + rg := engine.Group(bridge.BasePath()) + bridge.RegisterRoutes(rg) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPost, "/tools/route_choice", bytes.NewBufferString(`{"choice":"BC"}`)) + engine.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } +} + +func TestToolBridge_Bad_RejectsAmbiguousOneOfMatches(t *testing.T) { + gin.SetMode(gin.TestMode) + engine := gin.New() + + bridge := api.NewToolBridge("/tools") + bridge.Add(api.ToolDescriptor{ + Name: "route_choice", + Description: "Choose a route", + Group: "items", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "choice": map[string]any{ + "oneOf": []any{ + map[string]any{ + "type": "string", + "allOf": []any{ + map[string]any{"minLength": 1}, + map[string]any{"pattern": "^[A-Z]+$"}, + }, + }, + map[string]any{ + "type": "string", + "pattern": "^A", + }, + }, + }, + }, + "required": []any{"choice"}, + }, + }, func(c *gin.Context) { + c.JSON(http.StatusOK, api.OK("accepted")) + }) + + rg := engine.Group(bridge.BasePath()) + bridge.RegisterRoutes(rg) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPost, "/tools/route_choice", bytes.NewBufferString(`{"choice":"A"}`)) + engine.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", w.Code) + } + + var resp api.Response[any] + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.Success { + t.Fatal("expected Success=false") + } + if resp.Error == nil || resp.Error.Code != "invalid_request_body" { + t.Fatalf("expected invalid_request_body error, got %#v", resp.Error) + } +} + +func TestToolBridge_Bad_RejectsAdditionalProperties(t *testing.T) { + gin.SetMode(gin.TestMode) + engine := gin.New() + + bridge := api.NewToolBridge("/tools") + bridge.Add(api.ToolDescriptor{ + Name: "publish_item", + Description: "Publish an item", + Group: "items", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "status": map[string]any{"type": "string"}, + }, + "required": []any{"status"}, + "additionalProperties": false, + }, + }, func(c *gin.Context) { + c.JSON(http.StatusOK, api.OK("published")) + }) + + rg := engine.Group(bridge.BasePath()) + bridge.RegisterRoutes(rg) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPost, "/tools/publish_item", bytes.NewBufferString(`{"status":"published","unexpected":true}`)) + engine.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", w.Code) + } + + var resp api.Response[any] + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.Success { + t.Fatal("expected Success=false") + } + if resp.Error == nil || resp.Error.Code != "invalid_request_body" { + t.Fatalf("expected invalid_request_body error, got %#v", resp.Error) + } +} + +func TestToolBridge_Good_EnforcesStringConstraints(t *testing.T) { + gin.SetMode(gin.TestMode) + engine := gin.New() + + bridge := api.NewToolBridge("/tools") + bridge.Add(api.ToolDescriptor{ + Name: "publish_code", + Description: "Publish a code", + Group: "items", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "code": map[string]any{ + "type": "string", + "minLength": 3, + "maxLength": 5, + "pattern": "^[A-Z]+$", + }, + }, + "required": []any{"code"}, + }, + }, func(c *gin.Context) { + c.JSON(http.StatusOK, api.OK("accepted")) + }) + + rg := engine.Group(bridge.BasePath()) + bridge.RegisterRoutes(rg) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPost, "/tools/publish_code", bytes.NewBufferString(`{"code":"ABC"}`)) + engine.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } +} + +func TestToolBridge_Bad_RejectsNumericAndCollectionConstraints(t *testing.T) { + gin.SetMode(gin.TestMode) + engine := gin.New() + + bridge := api.NewToolBridge("/tools") + bridge.Add(api.ToolDescriptor{ + Name: "quota_check", + Description: "Check quotas", + Group: "items", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "count": map[string]any{ + "type": "integer", + "minimum": 1, + "maximum": 3, + }, + "labels": map[string]any{ + "type": "array", + "minItems": 2, + "maxItems": 4, + "items": map[string]any{ + "type": "string", + }, + }, + "payload": map[string]any{ + "type": "object", + "minProperties": 1, + "maxProperties": 2, + "additionalProperties": true, + }, + }, + "required": []any{"count", "labels", "payload"}, + }, + }, func(c *gin.Context) { + c.JSON(http.StatusOK, api.OK("accepted")) + }) + + rg := engine.Group(bridge.BasePath()) + bridge.RegisterRoutes(rg) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodPost, "/tools/quota_check", bytes.NewBufferString(`{"count":0,"labels":["one"],"payload":{}}`)) + engine.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected 400 for numeric/collection constraint failure, got %d", w.Code) + } + + var resp api.Response[any] + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.Success { + t.Fatal("expected Success=false") + } + if resp.Error == nil || resp.Error.Code != "invalid_request_body" { + t.Fatalf("expected invalid_request_body error, got %#v", resp.Error) + } +} + func TestToolBridge_Good_ToolsAccessor(t *testing.T) { bridge := api.NewToolBridge("/tools") bridge.Add(api.ToolDescriptor{Name: "alpha", Description: "Tool A", Group: "a"}, func(c *gin.Context) {}) diff --git a/cache.go b/cache.go index d032346..19d6bea 100644 --- a/cache.go +++ b/cache.go @@ -4,8 +4,10 @@ package api import ( "bytes" + "container/list" "maps" "net/http" + "strconv" "sync" "time" @@ -17,48 +19,151 @@ type cacheEntry struct { status int headers http.Header body []byte + size int expires time.Time } // cacheStore is a simple thread-safe in-memory cache keyed by request URL. type cacheStore struct { - mu sync.RWMutex - entries map[string]*cacheEntry + mu sync.RWMutex + entries map[string]*cacheEntry + order *list.List + index map[string]*list.Element + maxEntries int + maxBytes int + currentBytes int } // newCacheStore creates an empty cache store. -func newCacheStore() *cacheStore { +func newCacheStore(maxEntries, maxBytes int) *cacheStore { return &cacheStore{ - entries: make(map[string]*cacheEntry), + entries: make(map[string]*cacheEntry), + order: list.New(), + index: make(map[string]*list.Element), + maxEntries: maxEntries, + maxBytes: maxBytes, } } // get retrieves a non-expired entry for the given key. // Returns nil if the key is missing or expired. func (s *cacheStore) get(key string) *cacheEntry { - s.mu.RLock() + s.mu.Lock() entry, ok := s.entries[key] - s.mu.RUnlock() - if !ok { + s.mu.Unlock() return nil } + + // Check expiry before promoting in the LRU order so we never move a stale + // entry to the front. All expiry checking and eviction happen inside the + // same critical section to avoid a TOCTOU race. if time.Now().After(entry.expires) { - s.mu.Lock() + if elem, exists := s.index[key]; exists { + s.order.Remove(elem) + delete(s.index, key) + } + s.currentBytes -= entry.size + if s.currentBytes < 0 { + s.currentBytes = 0 + } delete(s.entries, key) s.mu.Unlock() return nil } + + // Only promote to LRU front after confirming the entry is still valid. + if elem, exists := s.index[key]; exists { + s.order.MoveToFront(elem) + } + s.mu.Unlock() return entry } // set stores a cache entry with the given TTL. func (s *cacheStore) set(key string, entry *cacheEntry) { s.mu.Lock() + if entry.size <= 0 { + entry.size = cacheEntrySize(entry.headers, entry.body) + } + + if elem, ok := s.index[key]; ok { + // Reject an oversized replacement before touching LRU state so the + // existing entry remains intact when the new value cannot fit. + if s.maxBytes > 0 && entry.size > s.maxBytes { + s.mu.Unlock() + return + } + if existing, exists := s.entries[key]; exists { + s.currentBytes -= existing.size + if s.currentBytes < 0 { + s.currentBytes = 0 + } + } + s.order.MoveToFront(elem) + s.entries[key] = entry + s.currentBytes += entry.size + s.evictBySizeLocked() + s.mu.Unlock() + return + } + + if s.maxBytes > 0 && entry.size > s.maxBytes { + s.mu.Unlock() + return + } + + for (s.maxEntries > 0 && len(s.entries) >= s.maxEntries) || s.wouldExceedBytesLocked(entry.size) { + if !s.evictOldestLocked() { + break + } + } + + if s.maxBytes > 0 && s.wouldExceedBytesLocked(entry.size) { + s.mu.Unlock() + return + } + s.entries[key] = entry + elem := s.order.PushFront(key) + s.index[key] = elem + s.currentBytes += entry.size s.mu.Unlock() } +func (s *cacheStore) wouldExceedBytesLocked(nextSize int) bool { + if s.maxBytes <= 0 { + return false + } + return s.currentBytes+nextSize > s.maxBytes +} + +func (s *cacheStore) evictBySizeLocked() { + for s.maxBytes > 0 && s.currentBytes > s.maxBytes { + if !s.evictOldestLocked() { + return + } + } +} + +func (s *cacheStore) evictOldestLocked() bool { + back := s.order.Back() + if back == nil { + return false + } + oldKey := back.Value.(string) + if existing, ok := s.entries[oldKey]; ok { + s.currentBytes -= existing.size + if s.currentBytes < 0 { + s.currentBytes = 0 + } + } + delete(s.entries, oldKey) + delete(s.index, oldKey) + s.order.Remove(back) + return true +} + // cacheWriter intercepts writes to capture the response body and status. type cacheWriter struct { gin.ResponseWriter @@ -89,14 +194,51 @@ func cacheMiddleware(store *cacheStore, ttl time.Duration) gin.HandlerFunc { // Serve from cache if a valid entry exists. if entry := store.get(key); entry != nil { + body := entry.body + metaRewritten := false + if meta := GetRequestMeta(c); meta != nil { + body = refreshCachedResponseMeta(entry.body, meta) + metaRewritten = true + } + + // staleValidatorHeader returns true for headers that describe the + // exact bytes of the cached body and must be dropped when the body + // has been rewritten by refreshCachedResponseMeta. + staleValidatorHeader := func(canonical string) bool { + if !metaRewritten { + return false + } + switch canonical { + case "Etag", "Content-Md5", "Digest": + return true + } + return false + } + for k, vals := range entry.headers { + canonical := http.CanonicalHeaderKey(k) + if canonical == "X-Request-Id" { + continue + } + if canonical == "Content-Length" { + continue + } + if staleValidatorHeader(canonical) { + continue + } for _, v := range vals { - c.Writer.Header().Set(k, v) + c.Writer.Header().Add(k, v) } } + if requestID := GetRequestID(c); requestID != "" { + c.Writer.Header().Set("X-Request-ID", requestID) + } else if requestID := c.GetHeader("X-Request-ID"); requestID != "" { + c.Writer.Header().Set("X-Request-ID", requestID) + } c.Writer.Header().Set("X-Cache", "HIT") + c.Writer.Header().Set("Content-Length", strconv.Itoa(len(body))) c.Writer.WriteHeader(entry.status) - _, _ = c.Writer.Write(entry.body) + _, _ = c.Writer.Write(body) c.Abort() return } @@ -119,8 +261,28 @@ func cacheMiddleware(store *cacheStore, ttl time.Duration) gin.HandlerFunc { status: status, headers: headers, body: cw.body.Bytes(), + size: cacheEntrySize(headers, cw.body.Bytes()), expires: time.Now().Add(ttl), }) } } } + +// refreshCachedResponseMeta updates the meta envelope in a cached JSON body so +// request-scoped metadata reflects the current request instead of the cache fill. +// Non-JSON bodies, malformed JSON, and responses without a top-level object are +// returned unchanged. +func refreshCachedResponseMeta(body []byte, meta *Meta) []byte { + return refreshResponseMetaBody(body, meta) +} + +func cacheEntrySize(headers http.Header, body []byte) int { + size := len(body) + for key, vals := range headers { + size += len(key) + for _, val := range vals { + size += len(val) + } + } + return size +} diff --git a/cache_config.go b/cache_config.go new file mode 100644 index 0000000..726ee0c --- /dev/null +++ b/cache_config.go @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api + +import "time" + +// CacheConfig captures the configured response cache settings for an Engine. +// +// It is intentionally small and serialisable so callers can inspect the active +// cache policy without needing to rebuild middleware state. +// +// Example: +// +// cfg := api.CacheConfig{Enabled: true, TTL: 5 * time.Minute} +type CacheConfig struct { + Enabled bool + TTL time.Duration + MaxEntries int + MaxBytes int +} + +// CacheConfig returns the currently configured response cache settings for the engine. +// +// The result snapshots the Engine state at call time. +// +// Example: +// +// cfg := engine.CacheConfig() +func (e *Engine) CacheConfig() CacheConfig { + if e == nil { + return CacheConfig{} + } + + cfg := CacheConfig{ + TTL: e.cacheTTL, + MaxEntries: e.cacheMaxEntries, + MaxBytes: e.cacheMaxBytes, + } + if e.cacheTTL > 0 { + cfg.Enabled = true + } + return cfg +} diff --git a/cache_test.go b/cache_test.go index 58820c3..4971240 100644 --- a/cache_test.go +++ b/cache_test.go @@ -40,6 +40,23 @@ func (g *cacheCounterGroup) RegisterRoutes(rg *gin.RouterGroup) { }) } +type cacheSizedGroup struct { + counter atomic.Int64 +} + +func (g *cacheSizedGroup) Name() string { return "cache-sized" } +func (g *cacheSizedGroup) BasePath() string { return "/cache" } +func (g *cacheSizedGroup) RegisterRoutes(rg *gin.RouterGroup) { + rg.GET("/small", func(c *gin.Context) { + n := g.counter.Add(1) + c.JSON(http.StatusOK, api.OK(fmt.Sprintf("small-%d-%s", n, strings.Repeat("a", 96)))) + }) + rg.GET("/large", func(c *gin.Context) { + n := g.counter.Add(1) + c.JSON(http.StatusOK, api.OK(fmt.Sprintf("large-%d-%s", n, strings.Repeat("b", 96)))) + }) +} + // ── WithCache ─────────────────────────────────────────────────────────── func TestWithCache_Good_CachesGETResponse(t *testing.T) { @@ -89,6 +106,36 @@ func TestWithCache_Good_CachesGETResponse(t *testing.T) { } } +func TestWithCacheLimits_Good_CachesGETResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + grp := &cacheCounterGroup{} + e, _ := api.New(api.WithCacheLimits(5*time.Second, 1, 0)) + e.Register(grp) + + h := e.Handler() + + w1 := httptest.NewRecorder() + req1, _ := http.NewRequest(http.MethodGet, "/cache/counter", nil) + h.ServeHTTP(w1, req1) + if w1.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w1.Code) + } + + w2 := httptest.NewRecorder() + req2, _ := http.NewRequest(http.MethodGet, "/cache/counter", nil) + h.ServeHTTP(w2, req2) + if w2.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w2.Code) + } + + if got := w2.Header().Get("X-Cache"); got != "HIT" { + t.Fatalf("expected X-Cache=HIT, got %q", got) + } + if grp.counter.Load() != 1 { + t.Fatalf("expected counter=1 (cached), got %d", grp.counter.Load()) + } +} + func TestWithCache_Good_POSTNotCached(t *testing.T) { gin.SetMode(gin.TestMode) grp := &cacheCounterGroup{} @@ -214,6 +261,189 @@ func TestWithCache_Good_CombinesWithOtherMiddleware(t *testing.T) { } } +func TestWithCache_Good_PreservesCurrentRequestIDOnHit(t *testing.T) { + gin.SetMode(gin.TestMode) + grp := &cacheCounterGroup{} + e, _ := api.New( + api.WithRequestID(), + api.WithCache(5*time.Second), + ) + e.Register(grp) + + h := e.Handler() + + w1 := httptest.NewRecorder() + req1, _ := http.NewRequest(http.MethodGet, "/cache/counter", nil) + req1.Header.Set("X-Request-ID", "first-request-id") + h.ServeHTTP(w1, req1) + if w1.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w1.Code) + } + if got := w1.Header().Get("X-Request-ID"); got != "first-request-id" { + t.Fatalf("expected first response request ID %q, got %q", "first-request-id", got) + } + + w2 := httptest.NewRecorder() + req2, _ := http.NewRequest(http.MethodGet, "/cache/counter", nil) + req2.Header.Set("X-Request-ID", "second-request-id") + h.ServeHTTP(w2, req2) + if w2.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w2.Code) + } + + if got := w2.Header().Get("X-Request-ID"); got != "second-request-id" { + t.Fatalf("expected cached response to preserve current request ID %q, got %q", "second-request-id", got) + } + if got := w2.Header().Get("X-Cache"); got != "HIT" { + t.Fatalf("expected X-Cache=HIT, got %q", got) + } + + var resp2 api.Response[string] + if err := json.Unmarshal(w2.Body.Bytes(), &resp2); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp2.Data != "call-1" { + t.Fatalf("expected cached response data %q, got %q", "call-1", resp2.Data) + } + if resp2.Meta == nil { + t.Fatal("expected cached response meta to be attached") + } + if resp2.Meta.RequestID != "second-request-id" { + t.Fatalf("expected cached response request_id=%q, got %q", "second-request-id", resp2.Meta.RequestID) + } + if resp2.Meta.Duration == "" { + t.Fatal("expected cached response duration to be refreshed") + } +} + +func TestWithCache_Good_PreservesCurrentRequestMetaOnHit(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New( + api.WithRequestID(), + api.WithCache(5*time.Second), + ) + e.Register(requestMetaTestGroup{}) + + h := e.Handler() + + w1 := httptest.NewRecorder() + req1, _ := http.NewRequest(http.MethodGet, "/v1/meta", nil) + req1.Header.Set("X-Request-ID", "first-request-id") + h.ServeHTTP(w1, req1) + if w1.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w1.Code) + } + + var resp1 api.Response[string] + if err := json.Unmarshal(w1.Body.Bytes(), &resp1); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp1.Meta == nil { + t.Fatal("expected meta on first response") + } + if resp1.Meta.RequestID != "first-request-id" { + t.Fatalf("expected first response request_id=%q, got %q", "first-request-id", resp1.Meta.RequestID) + } + + w2 := httptest.NewRecorder() + req2, _ := http.NewRequest(http.MethodGet, "/v1/meta", nil) + req2.Header.Set("X-Request-ID", "second-request-id") + h.ServeHTTP(w2, req2) + if w2.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w2.Code) + } + + var resp2 api.Response[string] + if err := json.Unmarshal(w2.Body.Bytes(), &resp2); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp2.Meta == nil { + t.Fatal("expected meta on cached response") + } + if resp2.Meta.RequestID != "second-request-id" { + t.Fatalf("expected cached response request_id=%q, got %q", "second-request-id", resp2.Meta.RequestID) + } + if resp2.Meta.Duration == "" { + t.Fatal("expected cached response duration to be refreshed") + } + if resp2.Meta.Page != 1 || resp2.Meta.PerPage != 25 || resp2.Meta.Total != 100 { + t.Fatalf("expected pagination metadata to remain intact, got %+v", resp2.Meta) + } + if got := w2.Header().Get("X-Request-ID"); got != "second-request-id" { + t.Fatalf("expected response header X-Request-ID=%q, got %q", "second-request-id", got) + } +} + +type cacheHeaderGroup struct{} + +func (cacheHeaderGroup) Name() string { return "cache-headers" } +func (cacheHeaderGroup) BasePath() string { return "/cache" } +func (cacheHeaderGroup) RegisterRoutes(rg *gin.RouterGroup) { + rg.GET("/multi", func(c *gin.Context) { + c.Writer.Header().Add("Link", "; rel=\"next\"") + c.Writer.Header().Add("Link", "; rel=\"prev\"") + c.JSON(http.StatusOK, api.OK("cached")) + }) +} + +func TestWithCache_Good_PreservesMultiValueHeadersOnHit(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithCache(5 * time.Second)) + e.Register(cacheHeaderGroup{}) + + h := e.Handler() + + w1 := httptest.NewRecorder() + req1, _ := http.NewRequest(http.MethodGet, "/cache/multi", nil) + h.ServeHTTP(w1, req1) + if w1.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w1.Code) + } + + w2 := httptest.NewRecorder() + req2, _ := http.NewRequest(http.MethodGet, "/cache/multi", nil) + h.ServeHTTP(w2, req2) + if w2.Code != http.StatusOK { + t.Fatalf("expected 200 on cache hit, got %d", w2.Code) + } + + linkHeaders := w2.Header().Values("Link") + if len(linkHeaders) != 2 { + t.Fatalf("expected 2 Link headers on cache hit, got %v", linkHeaders) + } + if linkHeaders[0] != "; rel=\"next\"" { + t.Fatalf("expected first Link header to be preserved, got %q", linkHeaders[0]) + } + if linkHeaders[1] != "; rel=\"prev\"" { + t.Fatalf("expected second Link header to be preserved, got %q", linkHeaders[1]) + } +} + +func TestWithCache_Ugly_NonPositiveTTLDisablesMiddleware(t *testing.T) { + gin.SetMode(gin.TestMode) + grp := &cacheCounterGroup{} + e, _ := api.New(api.WithCache(0)) + e.Register(grp) + + h := e.Handler() + + for i := 0; i < 2; i++ { + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/cache/counter", nil) + h.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("expected request %d to succeed with disabled cache, got %d", i+1, w.Code) + } + if got := w.Header().Get("X-Cache"); got != "" { + t.Fatalf("expected no X-Cache header with disabled cache, got %q", got) + } + } + + if grp.counter.Load() != 2 { + t.Fatalf("expected counter=2 with disabled cache, got %d", grp.counter.Load()) + } +} + func TestWithCache_Good_ExpiredCacheMisses(t *testing.T) { gin.SetMode(gin.TestMode) grp := &cacheCounterGroup{} @@ -250,3 +480,75 @@ func TestWithCache_Good_ExpiredCacheMisses(t *testing.T) { t.Fatalf("expected counter=2, got %d", grp.counter.Load()) } } + +func TestWithCache_Good_EvictsWhenCapacityReached(t *testing.T) { + gin.SetMode(gin.TestMode) + grp := &cacheCounterGroup{} + e, _ := api.New(api.WithCache(5*time.Second, 1)) + e.Register(grp) + + h := e.Handler() + + w1 := httptest.NewRecorder() + req1, _ := http.NewRequest(http.MethodGet, "/cache/counter", nil) + h.ServeHTTP(w1, req1) + if !strings.Contains(w1.Body.String(), "call-1") { + t.Fatalf("expected first response to contain %q, got %q", "call-1", w1.Body.String()) + } + + w2 := httptest.NewRecorder() + req2, _ := http.NewRequest(http.MethodGet, "/cache/other", nil) + h.ServeHTTP(w2, req2) + if !strings.Contains(w2.Body.String(), "other-2") { + t.Fatalf("expected second response to contain %q, got %q", "other-2", w2.Body.String()) + } + + w3 := httptest.NewRecorder() + req3, _ := http.NewRequest(http.MethodGet, "/cache/counter", nil) + h.ServeHTTP(w3, req3) + if !strings.Contains(w3.Body.String(), "call-3") { + t.Fatalf("expected evicted response to contain %q, got %q", "call-3", w3.Body.String()) + } + + if grp.counter.Load() != 3 { + t.Fatalf("expected counter=3 after eviction, got %d", grp.counter.Load()) + } +} + +func TestWithCache_Good_EvictsWhenSizeLimitReached(t *testing.T) { + gin.SetMode(gin.TestMode) + grp := &cacheSizedGroup{} + e, _ := api.New(api.WithCacheLimits(5*time.Second, 10, 250)) + e.Register(grp) + + h := e.Handler() + + w1 := httptest.NewRecorder() + req1, _ := http.NewRequest(http.MethodGet, "/cache/small", nil) + h.ServeHTTP(w1, req1) + if !strings.Contains(w1.Body.String(), "small-1") { + t.Fatalf("expected first response to contain %q, got %q", "small-1", w1.Body.String()) + } + + w2 := httptest.NewRecorder() + req2, _ := http.NewRequest(http.MethodGet, "/cache/large", nil) + h.ServeHTTP(w2, req2) + if !strings.Contains(w2.Body.String(), "large-2") { + t.Fatalf("expected second response to contain %q, got %q", "large-2", w2.Body.String()) + } + + w3 := httptest.NewRecorder() + req3, _ := http.NewRequest(http.MethodGet, "/cache/small", nil) + h.ServeHTTP(w3, req3) + if !strings.Contains(w3.Body.String(), "small-3") { + t.Fatalf("expected size-limited cache to evict the oldest entry, got %q", w3.Body.String()) + } + + if got := w3.Header().Get("X-Cache"); got != "" { + t.Fatalf("expected re-executed response to miss the cache, got X-Cache=%q", got) + } + + if grp.counter.Load() != 3 { + t.Fatalf("expected counter=3 after size-based eviction, got %d", grp.counter.Load()) + } +} diff --git a/client.go b/client.go new file mode 100644 index 0000000..949e860 --- /dev/null +++ b/client.go @@ -0,0 +1,1038 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "iter" + "net/http" + "net/url" + "os" + "reflect" + "sort" + "strings" + "sync" + + "slices" + + "gopkg.in/yaml.v3" + + coreerr "dappco.re/go/core/log" +) + +// OpenAPIClient is a small runtime client that can call operations by their +// OpenAPI operationId. It loads the spec once, resolves the HTTP method and +// path for each operation, and performs JSON request/response handling. +// +// Example: +// +// client := api.NewOpenAPIClient(api.WithSpec("./openapi.yaml"), api.WithBaseURL("https://api.example.com")) +// data, err := client.Call("get_health", nil) +type OpenAPIClient struct { + specPath string + specReader io.Reader + baseURL string + bearerToken string + httpClient *http.Client + + once sync.Once + operations map[string]openAPIOperation + servers []string + loadErr error +} + +// OpenAPIOperation snapshots the public metadata for a single loaded OpenAPI +// operation. +// +// Example: +// +// ops, err := client.Operations() +// if err == nil && len(ops) > 0 { +// fmt.Println(ops[0].OperationID, ops[0].PathTemplate) +// } +type OpenAPIOperation struct { + OperationID string + Method string + PathTemplate string + HasRequestBody bool + Parameters []OpenAPIParameter +} + +// OpenAPIParameter snapshots a single OpenAPI parameter definition. +// +// Example: +// +// op, err := client.Operations() +// if err == nil && len(op) > 0 && len(op[0].Parameters) > 0 { +// fmt.Println(op[0].Parameters[0].Name, op[0].Parameters[0].In) +// } +type OpenAPIParameter struct { + Name string + In string + Required bool + Schema map[string]any +} + +type openAPIOperation struct { + method string + pathTemplate string + hasRequestBody bool + parameters []openAPIParameter + requestSchema map[string]any + responseSchema map[string]any +} + +type openAPIParameter struct { + name string + in string + required bool + schema map[string]any +} + +// OpenAPIClientOption configures a runtime OpenAPI client. +// +// Example: +// +// client := api.NewOpenAPIClient(api.WithSpec("./openapi.yaml")) +type OpenAPIClientOption func(*OpenAPIClient) + +// WithSpec sets the filesystem path to the OpenAPI document. +// +// Example: +// +// client := api.NewOpenAPIClient(api.WithSpec("./openapi.yaml")) +func WithSpec(path string) OpenAPIClientOption { + return func(c *OpenAPIClient) { + c.specPath = path + } +} + +// WithSpecReader sets an in-memory or streamed OpenAPI document source. +// It is read once the first time the client loads its spec. +// +// Example: +// +// client := api.NewOpenAPIClient(api.WithSpecReader(strings.NewReader(spec))) +func WithSpecReader(reader io.Reader) OpenAPIClientOption { + return func(c *OpenAPIClient) { + c.specReader = reader + } +} + +// WithBaseURL sets the base URL used for outgoing requests. +// +// Example: +// +// client := api.NewOpenAPIClient(api.WithBaseURL("https://api.example.com")) +func WithBaseURL(baseURL string) OpenAPIClientOption { + return func(c *OpenAPIClient) { + c.baseURL = baseURL + } +} + +// WithBearerToken sets the Authorization bearer token used for requests. +// +// Example: +// +// client := api.NewOpenAPIClient( +// api.WithBaseURL("https://api.example.com"), +// api.WithBearerToken("secret-token"), +// ) +func WithBearerToken(token string) OpenAPIClientOption { + return func(c *OpenAPIClient) { + c.bearerToken = token + } +} + +// WithHTTPClient sets the HTTP client used to execute requests. +// +// Example: +// +// client := api.NewOpenAPIClient(api.WithHTTPClient(http.DefaultClient)) +func WithHTTPClient(client *http.Client) OpenAPIClientOption { + return func(c *OpenAPIClient) { + c.httpClient = client + } +} + +// NewOpenAPIClient constructs a runtime client for calling OpenAPI operations. +// +// Example: +// +// client := api.NewOpenAPIClient(api.WithSpec("./openapi.yaml")) +func NewOpenAPIClient(opts ...OpenAPIClientOption) *OpenAPIClient { + c := &OpenAPIClient{ + httpClient: http.DefaultClient, + } + for _, opt := range opts { + opt(c) + } + if c.httpClient == nil { + c.httpClient = http.DefaultClient + } + return c +} + +// Operations returns a snapshot of the operations loaded from the OpenAPI +// document. +// +// Example: +// +// ops, err := client.Operations() +func (c *OpenAPIClient) Operations() ([]OpenAPIOperation, error) { + if err := c.load(); err != nil { + return nil, err + } + + operations := make([]OpenAPIOperation, 0, len(c.operations)) + for operationID, op := range c.operations { + operations = append(operations, snapshotOpenAPIOperation(operationID, op)) + } + sort.SliceStable(operations, func(i, j int) bool { + if operations[i].OperationID == operations[j].OperationID { + if operations[i].Method == operations[j].Method { + return operations[i].PathTemplate < operations[j].PathTemplate + } + return operations[i].Method < operations[j].Method + } + return operations[i].OperationID < operations[j].OperationID + }) + return operations, nil +} + +// OperationsIter returns an iterator over the loaded OpenAPI operations. +// +// Example: +// +// ops, err := client.OperationsIter() +// if err != nil { +// panic(err) +// } +// for op := range ops { +// fmt.Println(op.OperationID, op.PathTemplate) +// } +func (c *OpenAPIClient) OperationsIter() (iter.Seq[OpenAPIOperation], error) { + operations, err := c.Operations() + if err != nil { + return nil, err + } + + return func(yield func(OpenAPIOperation) bool) { + for _, op := range operations { + if !yield(op) { + return + } + } + }, nil +} + +// Servers returns a snapshot of the server URLs discovered from the OpenAPI +// document. +// +// Example: +// +// servers, err := client.Servers() +func (c *OpenAPIClient) Servers() ([]string, error) { + if err := c.load(); err != nil { + return nil, err + } + + return slices.Clone(c.servers), nil +} + +// ServersIter returns an iterator over the server URLs discovered from the +// OpenAPI document. +// +// Example: +// +// servers, err := client.ServersIter() +// if err != nil { +// panic(err) +// } +// for server := range servers { +// fmt.Println(server) +// } +func (c *OpenAPIClient) ServersIter() (iter.Seq[string], error) { + servers, err := c.Servers() + if err != nil { + return nil, err + } + + return func(yield func(string) bool) { + for _, server := range servers { + if !yield(server) { + return + } + } + }, nil +} + +// Call invokes the operation with the given operationId. +// +// The params argument may be a map, struct, or nil. For convenience, a map may +// include "path", "query", "header", "cookie", and "body" keys to explicitly +// control where the values are sent. When no explicit body is provided, +// requests with a declared requestBody send the remaining parameters as JSON. +// +// Example: +// +// data, err := client.Call("create_item", map[string]any{"name": "alpha"}) +func (c *OpenAPIClient) Call(operationID string, params any) (any, error) { + if err := c.load(); err != nil { + return nil, err + } + if c.httpClient == nil { + c.httpClient = http.DefaultClient + } + + op, ok := c.operations[operationID] + if !ok { + return nil, coreerr.E("OpenAPIClient.Call", fmt.Sprintf("operation %q not found in OpenAPI spec", operationID), nil) + } + + merged, err := normaliseParams(params) + if err != nil { + return nil, err + } + + requestURL, err := c.buildURL(op, merged) + if err != nil { + return nil, err + } + + body, err := c.buildBody(op, merged) + if err != nil { + return nil, err + } + + if op.requestSchema != nil && len(body) > 0 { + if err := validateOpenAPISchema(body, op.requestSchema, "request body"); err != nil { + return nil, err + } + } + + var bodyReader io.Reader + if len(body) > 0 { + bodyReader = bytes.NewReader(body) + } + + req, err := http.NewRequest(op.method, requestURL, bodyReader) + if err != nil { + return nil, err + } + if bodyReader != nil { + req.Header.Set("Content-Type", "application/json") + } + if c.bearerToken != "" { + req.Header.Set("Authorization", "Bearer "+c.bearerToken) + } + applyRequestParameters(req, op, merged) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + payload, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, coreerr.E("OpenAPIClient.Call", fmt.Sprintf("openapi call %s returned %s: %s", operationID, resp.Status, strings.TrimSpace(string(payload))), nil) + } + + if op.responseSchema != nil && len(bytes.TrimSpace(payload)) > 0 { + if err := validateOpenAPIResponse(payload, op.responseSchema, operationID); err != nil { + return nil, err + } + } + + if len(bytes.TrimSpace(payload)) == 0 { + return nil, nil + } + + var decoded any + dec := json.NewDecoder(bytes.NewReader(payload)) + dec.UseNumber() + if err := dec.Decode(&decoded); err != nil { + return string(payload), nil + } + + if envelope, ok := decoded.(map[string]any); ok { + if success, ok := envelope["success"].(bool); ok { + if !success { + if errObj, ok := envelope["error"].(map[string]any); ok { + return nil, coreerr.E("OpenAPIClient.Call", fmt.Sprintf("openapi call %s failed: %v", operationID, errObj), nil) + } + return nil, coreerr.E("OpenAPIClient.Call", fmt.Sprintf("openapi call %s failed", operationID), nil) + } + if data, ok := envelope["data"]; ok { + return data, nil + } + } + } + + return decoded, nil +} + +func (c *OpenAPIClient) load() error { + c.once.Do(func() { + c.loadErr = c.loadSpec() + }) + return c.loadErr +} + +func (c *OpenAPIClient) loadSpec() error { + var ( + data []byte + err error + ) + + switch { + case c.specReader != nil: + data, err = io.ReadAll(c.specReader) + case c.specPath != "": + f, openErr := os.Open(c.specPath) + if openErr != nil { + return coreerr.E("OpenAPIClient.loadSpec", "read spec", openErr) + } + defer f.Close() + data, err = io.ReadAll(f) + default: + return coreerr.E("OpenAPIClient.loadSpec", "spec path or reader is required", nil) + } + + if err != nil { + return coreerr.E("OpenAPIClient.loadSpec", "read spec", err) + } + + var spec map[string]any + if err := yaml.Unmarshal(data, &spec); err != nil { + return coreerr.E("OpenAPIClient.loadSpec", "parse spec", err) + } + + operations := make(map[string]openAPIOperation) + if paths, ok := spec["paths"].(map[string]any); ok { + for pathTemplate, rawPathItem := range paths { + pathItem, ok := rawPathItem.(map[string]any) + if !ok { + continue + } + for method, rawOperation := range pathItem { + operation, ok := rawOperation.(map[string]any) + if !ok { + continue + } + operationID, _ := operation["operationId"].(string) + if operationID == "" { + continue + } + params := parseOperationParameters(operation) + operations[operationID] = openAPIOperation{ + method: strings.ToUpper(method), + pathTemplate: pathTemplate, + hasRequestBody: operation["requestBody"] != nil, + parameters: params, + requestSchema: requestBodySchema(operation), + responseSchema: firstSuccessResponseSchema(operation), + } + } + } + } + + c.operations = operations + if servers, ok := spec["servers"].([]any); ok { + for _, rawServer := range servers { + server, ok := rawServer.(map[string]any) + if !ok { + continue + } + if u, _ := server["url"].(string); u != "" { + c.servers = append(c.servers, u) + } + } + } + c.servers = normaliseServers(c.servers) + + if c.baseURL == "" { + for _, server := range c.servers { + if isAbsoluteBaseURL(server) { + c.baseURL = server + break + } + } + } + + return nil +} + +func snapshotOpenAPIOperation(operationID string, op openAPIOperation) OpenAPIOperation { + parameters := make([]OpenAPIParameter, len(op.parameters)) + for i, param := range op.parameters { + parameters[i] = OpenAPIParameter{ + Name: param.name, + In: param.in, + Required: param.required, + Schema: cloneOpenAPIObject(param.schema), + } + } + + return OpenAPIOperation{ + OperationID: operationID, + Method: strings.ToUpper(op.method), + PathTemplate: op.pathTemplate, + HasRequestBody: op.hasRequestBody, + Parameters: parameters, + } +} + +func (c *OpenAPIClient) buildURL(op openAPIOperation, params map[string]any) (string, error) { + base := strings.TrimRight(c.baseURL, "/") + if base == "" { + return "", coreerr.E("OpenAPIClient.buildURL", "base URL is required", nil) + } + + path := op.pathTemplate + pathKeys := pathParameterNames(path) + pathValues := map[string]any{} + if explicitPath, ok := nestedMap(params, "path"); ok { + pathValues = explicitPath + } else { + pathValues = params + } + + if err := validateRequiredParameters(op, params, pathKeys); err != nil { + return "", err + } + if err := validateParameterValues(op, params); err != nil { + return "", err + } + + for _, key := range pathKeys { + if value, ok := pathValues[key]; ok { + placeholder := "{" + key + "}" + path = strings.ReplaceAll(path, placeholder, url.PathEscape(fmt.Sprint(value))) + } + } + + if strings.Contains(path, "{") { + return "", coreerr.E("OpenAPIClient.buildURL", fmt.Sprintf("missing path parameters for %q", op.pathTemplate), nil) + } + + fullURL, err := url.JoinPath(base, path) + if err != nil { + return "", err + } + + query := url.Values{} + if explicitQuery, ok := nestedMap(params, "query"); ok { + for key, value := range explicitQuery { + appendQueryValue(query, key, value) + } + } + for key, value := range params { + if key == "path" || key == "body" || key == "query" || key == "header" || key == "cookie" { + continue + } + if containsString(pathKeys, key) { + continue + } + location := operationParameterLocation(op, key) + if location != "query" && !(location == "" && (op.method == http.MethodGet || (op.method == http.MethodHead && !op.hasRequestBody))) { + continue + } + if _, exists := query[key]; exists { + continue + } + appendQueryValue(query, key, value) + } + + if encoded := query.Encode(); encoded != "" { + fullURL += "?" + encoded + } + + return fullURL, nil +} + +func (c *OpenAPIClient) buildBody(op openAPIOperation, params map[string]any) ([]byte, error) { + if explicitBody, ok := params["body"]; ok { + return encodeJSONBody(explicitBody) + } + + if op.method == http.MethodGet || (op.method == http.MethodHead && !op.hasRequestBody) { + return nil, nil + } + + if len(params) == 0 { + return nil, nil + } + + pathKeys := pathParameterNames(op.pathTemplate) + queryKeys := map[string]struct{}{} + if explicitQuery, ok := nestedMap(params, "query"); ok { + for key := range explicitQuery { + queryKeys[key] = struct{}{} + } + } + + payload := make(map[string]any, len(params)) + for key, value := range params { + if key == "path" || key == "query" || key == "body" || key == "header" || key == "cookie" { + continue + } + if containsString(pathKeys, key) { + continue + } + switch operationParameterLocation(op, key) { + case "header", "cookie", "query": + continue + } + if _, exists := queryKeys[key]; exists { + continue + } + payload[key] = value + } + if len(payload) == 0 { + return nil, nil + } + return encodeJSONBody(payload) +} + +func applyRequestParameters(req *http.Request, op openAPIOperation, params map[string]any) { + explicitHeaders, hasExplicitHeaders := nestedMap(params, "header") + explicitCookies, hasExplicitCookies := nestedMap(params, "cookie") + + if hasExplicitHeaders { + applyHeaderValues(req.Header, explicitHeaders) + } + + applyTopLevelHeaderParameters(req.Header, op, params, explicitHeaders, hasExplicitHeaders) + + if hasExplicitCookies { + applyCookieValues(req, explicitCookies) + } + applyTopLevelCookieParameters(req, op, params, explicitCookies, hasExplicitCookies) +} + +func applyTopLevelHeaderParameters(headers http.Header, op openAPIOperation, params, explicit map[string]any, hasExplicit bool) { + for key, value := range params { + if key == "path" || key == "query" || key == "body" || key == "header" || key == "cookie" { + continue + } + if operationParameterLocation(op, key) != "header" { + continue + } + if hasExplicit { + if _, ok := explicit[key]; ok { + continue + } + } + applyHeaderValue(headers, key, value) + } +} + +func applyTopLevelCookieParameters(req *http.Request, op openAPIOperation, params, explicit map[string]any, hasExplicit bool) { + for key, value := range params { + if key == "path" || key == "query" || key == "body" || key == "header" || key == "cookie" { + continue + } + if operationParameterLocation(op, key) != "cookie" { + continue + } + if hasExplicit { + if _, ok := explicit[key]; ok { + continue + } + } + applyCookieValue(req, key, value) + } +} + +func applyHeaderValues(headers http.Header, values map[string]any) { + for key, value := range values { + applyHeaderValue(headers, key, value) + } +} + +func applyHeaderValue(headers http.Header, key string, value any) { + switch v := value.(type) { + case nil: + return + case []string: + for _, item := range v { + headers.Add(key, item) + } + return + case []any: + for _, item := range v { + headers.Add(key, fmt.Sprint(item)) + } + return + } + + rv := reflect.ValueOf(value) + if rv.IsValid() && (rv.Kind() == reflect.Slice || rv.Kind() == reflect.Array) && !(rv.Type().Elem().Kind() == reflect.Uint8) { + for i := 0; i < rv.Len(); i++ { + headers.Add(key, fmt.Sprint(rv.Index(i).Interface())) + } + return + } + + headers.Set(key, fmt.Sprint(value)) +} + +func applyCookieValues(req *http.Request, values map[string]any) { + for key, value := range values { + applyCookieValue(req, key, value) + } +} + +func applyCookieValue(req *http.Request, key string, value any) { + switch v := value.(type) { + case nil: + return + case []string: + for _, item := range v { + req.AddCookie(&http.Cookie{Name: key, Value: item}) + } + return + case []any: + for _, item := range v { + req.AddCookie(&http.Cookie{Name: key, Value: fmt.Sprint(item)}) + } + return + } + + rv := reflect.ValueOf(value) + if rv.IsValid() && (rv.Kind() == reflect.Slice || rv.Kind() == reflect.Array) && !(rv.Type().Elem().Kind() == reflect.Uint8) { + for i := 0; i < rv.Len(); i++ { + req.AddCookie(&http.Cookie{Name: key, Value: fmt.Sprint(rv.Index(i).Interface())}) + } + return + } + + req.AddCookie(&http.Cookie{Name: key, Value: fmt.Sprint(value)}) +} + +func parseOperationParameters(operation map[string]any) []openAPIParameter { + rawParams, ok := operation["parameters"].([]any) + if !ok { + return nil + } + + params := make([]openAPIParameter, 0, len(rawParams)) + for _, rawParam := range rawParams { + param, ok := rawParam.(map[string]any) + if !ok { + continue + } + name, _ := param["name"].(string) + in, _ := param["in"].(string) + if name == "" || in == "" { + continue + } + required, _ := param["required"].(bool) + schema, _ := param["schema"].(map[string]any) + params = append(params, openAPIParameter{name: name, in: in, required: required, schema: schema}) + } + + return params +} + +func operationParameterLocation(op openAPIOperation, name string) string { + for _, param := range op.parameters { + if param.name == name { + return param.in + } + } + return "" +} + +func validateParameterValues(op openAPIOperation, params map[string]any) error { + for _, param := range op.parameters { + if len(param.schema) == 0 { + continue + } + + if nested, ok := nestedMap(params, param.in); ok { + if value, exists := nested[param.name]; exists { + if err := validateParameterValue(param, value); err != nil { + return err + } + continue + } + } + + if value, exists := params[param.name]; exists { + if err := validateParameterValue(param, value); err != nil { + return err + } + } + } + return nil +} + +func validateParameterValue(param openAPIParameter, value any) error { + if value == nil { + return nil + } + + data, err := json.Marshal(value) + if err != nil { + return coreerr.E("OpenAPIClient.validateParameterValue", fmt.Sprintf("marshal %s parameter %q", param.in, param.name), err) + } + if err := validateOpenAPISchema(data, param.schema, fmt.Sprintf("%s parameter %q", param.in, param.name)); err != nil { + return err + } + return nil +} + +func validateRequiredParameters(op openAPIOperation, params map[string]any, pathKeys []string) error { + for _, param := range op.parameters { + if !param.required { + continue + } + if parameterProvided(params, param.name, param.in) { + continue + } + return coreerr.E("OpenAPIClient.buildURL", fmt.Sprintf("missing required %s parameter %q", param.in, param.name), nil) + } + return nil +} + +func parameterProvided(params map[string]any, name, location string) bool { + if nested, ok := nestedMap(params, location); ok { + if value, exists := nested[name]; exists && value != nil { + return true + } + } + + if value, exists := params[name]; exists { + if value != nil { + return true + } + } + + return false +} + +func encodeJSONBody(v any) ([]byte, error) { + data, err := json.Marshal(v) + if err != nil { + return nil, err + } + return data, nil +} + +func normaliseParams(params any) (map[string]any, error) { + if params == nil { + return map[string]any{}, nil + } + + if m, ok := params.(map[string]any); ok { + return m, nil + } + + data, err := json.Marshal(params) + if err != nil { + return nil, coreerr.E("OpenAPIClient.normaliseParams", "marshal params", err) + } + + var out map[string]any + if err := json.Unmarshal(data, &out); err != nil { + return nil, coreerr.E("OpenAPIClient.normaliseParams", "decode params", err) + } + + return out, nil +} + +func nestedMap(params map[string]any, key string) (map[string]any, bool) { + raw, ok := params[key] + if !ok { + return nil, false + } + + m, ok := raw.(map[string]any) + if ok { + return m, true + } + + data, err := json.Marshal(raw) + if err != nil { + return nil, false + } + if err := json.Unmarshal(data, &m); err != nil { + return nil, false + } + return m, true +} + +func pathParameterNames(pathTemplate string) []string { + var names []string + for i := 0; i < len(pathTemplate); i++ { + if pathTemplate[i] != '{' { + continue + } + end := strings.IndexByte(pathTemplate[i+1:], '}') + if end < 0 { + break + } + name := pathTemplate[i+1 : i+1+end] + if name != "" { + names = append(names, name) + } + i += end + 1 + } + return names +} + +func containsString(values []string, target string) bool { + for _, value := range values { + if value == target { + return true + } + } + return false +} + +func appendQueryValue(query url.Values, key string, value any) { + switch v := value.(type) { + case nil: + return + case []byte: + query.Add(key, string(v)) + return + case []string: + for _, item := range v { + query.Add(key, item) + } + return + case []any: + for _, item := range v { + appendQueryValue(query, key, item) + } + return + } + + rv := reflect.ValueOf(value) + if !rv.IsValid() { + return + } + + switch rv.Kind() { + case reflect.Slice, reflect.Array: + if rv.Type().Elem().Kind() == reflect.Uint8 { + query.Add(key, string(rv.Bytes())) + return + } + for i := 0; i < rv.Len(); i++ { + appendQueryValue(query, key, rv.Index(i).Interface()) + } + return + } + + query.Add(key, fmt.Sprint(value)) +} + +func isAbsoluteBaseURL(raw string) bool { + u, err := url.Parse(raw) + if err != nil { + return false + } + return u.Scheme != "" && u.Host != "" +} + +func requestBodySchema(operation map[string]any) map[string]any { + rawRequestBody, ok := operation["requestBody"].(map[string]any) + if !ok { + return nil + } + + content, ok := rawRequestBody["content"].(map[string]any) + if !ok { + return nil + } + + rawJSON, ok := content["application/json"].(map[string]any) + if !ok { + return nil + } + + schema, _ := rawJSON["schema"].(map[string]any) + return schema +} + +func firstSuccessResponseSchema(operation map[string]any) map[string]any { + responses, ok := operation["responses"].(map[string]any) + if !ok { + return nil + } + + for _, code := range []string{"200", "201", "202", "203", "204", "205", "206", "207", "208", "226"} { + rawResp, ok := responses[code].(map[string]any) + if !ok { + continue + } + content, ok := rawResp["content"].(map[string]any) + if !ok { + continue + } + rawJSON, ok := content["application/json"].(map[string]any) + if !ok { + continue + } + schema, _ := rawJSON["schema"].(map[string]any) + if len(schema) > 0 { + return schema + } + } + + return nil +} + +func validateOpenAPISchema(body []byte, schema map[string]any, label string) error { + if len(bytes.TrimSpace(body)) == 0 { + return nil + } + + var payload any + dec := json.NewDecoder(bytes.NewReader(body)) + dec.UseNumber() + if err := dec.Decode(&payload); err != nil { + return coreerr.E("OpenAPIClient.validateOpenAPISchema", fmt.Sprintf("validate %s: invalid JSON", label), err) + } + var extra any + if err := dec.Decode(&extra); err != io.EOF { + return coreerr.E("OpenAPIClient.validateOpenAPISchema", fmt.Sprintf("validate %s: expected a single JSON value", label), nil) + } + + if err := validateSchemaNode(payload, schema, ""); err != nil { + return coreerr.E("OpenAPIClient.validateOpenAPISchema", fmt.Sprintf("validate %s", label), err) + } + + return nil +} + +func validateOpenAPIResponse(payload []byte, schema map[string]any, operationID string) error { + var decoded any + dec := json.NewDecoder(bytes.NewReader(payload)) + dec.UseNumber() + if err := dec.Decode(&decoded); err != nil { + return coreerr.E("OpenAPIClient.validateOpenAPIResponse", fmt.Sprintf("openapi call %s returned invalid JSON", operationID), err) + } + var extra any + if err := dec.Decode(&extra); err != io.EOF { + return coreerr.E("OpenAPIClient.validateOpenAPIResponse", fmt.Sprintf("openapi call %s returned multiple JSON values", operationID), nil) + } + + if err := validateSchemaNode(decoded, schema, ""); err != nil { + return coreerr.E("OpenAPIClient.validateOpenAPIResponse", fmt.Sprintf("openapi call %s response does not match spec", operationID), err) + } + + return nil +} diff --git a/client_test.go b/client_test.go new file mode 100644 index 0000000..b0b2530 --- /dev/null +++ b/client_test.go @@ -0,0 +1,963 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api_test + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + + "slices" + + api "dappco.re/go/core/api" +) + +func TestOpenAPIClient_Good_CallOperationByID(t *testing.T) { + errCh := make(chan error, 2) + mux := http.NewServeMux() + mux.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + errCh <- fmt.Errorf("expected GET, got %s", r.Method) + w.WriteHeader(http.StatusInternalServerError) + return + } + if got := r.URL.Query().Get("name"); got != "Ada" { + errCh <- fmt.Errorf("expected query name=Ada, got %q", got) + w.WriteHeader(http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"success":true,"data":{"message":"hello"}}`)) + }) + mux.HandleFunc("/users/123", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + errCh <- fmt.Errorf("expected POST, got %s", r.Method) + w.WriteHeader(http.StatusInternalServerError) + return + } + if got := r.URL.Query().Get("verbose"); got != "true" { + errCh <- fmt.Errorf("expected query verbose=true, got %q", got) + w.WriteHeader(http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"success":true,"data":{"id":"123","name":"Ada"}}`)) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + specPath := writeTempSpec(t, `openapi: 3.1.0 +info: + title: Test API + version: 1.0.0 +paths: + /hello: + get: + operationId: get_hello + /users/{id}: + post: + operationId: update_user + requestBody: + required: true + content: + application/json: + schema: + type: object +`) + + client := api.NewOpenAPIClient( + api.WithSpec(specPath), + api.WithBaseURL(srv.URL), + ) + + result, err := client.Call("get_hello", map[string]any{ + "name": "Ada", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + select { + case err := <-errCh: + t.Fatal(err) + default: + } + + hello, ok := result.(map[string]any) + if !ok { + t.Fatalf("expected map result, got %T", result) + } + if hello["message"] != "hello" { + t.Fatalf("expected message=hello, got %#v", hello["message"]) + } + + result, err = client.Call("update_user", map[string]any{ + "path": map[string]any{ + "id": "123", + }, + "query": map[string]any{ + "verbose": true, + }, + "body": map[string]any{ + "name": "Ada", + }, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + select { + case err := <-errCh: + t.Fatal(err) + default: + } + + updated, ok := result.(map[string]any) + if !ok { + t.Fatalf("expected map result, got %T", result) + } + if updated["id"] != "123" { + t.Fatalf("expected id=123, got %#v", updated["id"]) + } + if updated["name"] != "Ada" { + t.Fatalf("expected name=Ada, got %#v", updated["name"]) + } +} + +func TestOpenAPIClient_Good_LoadsSpecFromReader(t *testing.T) { + errCh := make(chan error, 1) + mux := http.NewServeMux() + mux.HandleFunc("/ping", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + errCh <- fmt.Errorf("expected GET, got %s", r.Method) + w.WriteHeader(http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"success":true,"data":{"message":"pong"}}`)) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + client := api.NewOpenAPIClient( + api.WithSpecReader(strings.NewReader(`openapi: 3.1.0 +info: + title: Test API + version: 1.0.0 +paths: + /ping: + get: + operationId: ping +`)), + api.WithBaseURL(srv.URL), + ) + + result, err := client.Call("ping", nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + select { + case err := <-errCh: + t.Fatal(err) + default: + } + + ping, ok := result.(map[string]any) + if !ok { + t.Fatalf("expected map result, got %T", result) + } + if ping["message"] != "pong" { + t.Fatalf("expected message=pong, got %#v", ping["message"]) + } +} + +func TestOpenAPIClient_Good_ExposesOperationSnapshots(t *testing.T) { + specPath := writeTempSpec(t, `openapi: 3.1.0 +info: + title: Test API + version: 1.0.0 +servers: + - url: https://api.example.com +paths: + /users/{id}: + post: + operationId: update_user + parameters: + - name: id + in: path + required: true + schema: + type: string + requestBody: + required: true + content: + application/json: + schema: + type: object +`) + + client := api.NewOpenAPIClient(api.WithSpec(specPath)) + + operations, err := client.Operations() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(operations) != 1 { + t.Fatalf("expected 1 operation, got %d", len(operations)) + } + + op := operations[0] + if op.OperationID != "update_user" { + t.Fatalf("expected operationId update_user, got %q", op.OperationID) + } + if op.Method != http.MethodPost { + t.Fatalf("expected method POST, got %q", op.Method) + } + if op.PathTemplate != "/users/{id}" { + t.Fatalf("expected path template /users/{id}, got %q", op.PathTemplate) + } + if !op.HasRequestBody { + t.Fatal("expected operation to report a request body") + } + if len(op.Parameters) != 1 || op.Parameters[0].Name != "id" { + t.Fatalf("expected one path parameter snapshot, got %+v", op.Parameters) + } + + op.Parameters[0].Schema["type"] = "integer" + operations[0].PathTemplate = "/mutated" + + again, err := client.Operations() + if err != nil { + t.Fatalf("unexpected error on re-read: %v", err) + } + if again[0].PathTemplate != "/users/{id}" { + t.Fatalf("expected snapshot to remain immutable, got %q", again[0].PathTemplate) + } + if got := again[0].Parameters[0].Schema["type"]; got != "string" { + t.Fatalf("expected cloned parameter schema, got %#v", got) + } +} + +func TestOpenAPIClient_Good_ExposesServerSnapshots(t *testing.T) { + client := api.NewOpenAPIClient(api.WithSpecReader(strings.NewReader(`openapi: 3.1.0 +info: + title: Test API + version: 1.0.0 +servers: + - url: https://api.example.com + - url: /relative +paths: {} +`))) + + servers, err := client.Servers() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !slices.Equal(servers, []string{"https://api.example.com", "/relative"}) { + t.Fatalf("expected server snapshot to preserve order, got %v", servers) + } + + servers[0] = "https://mutated.example.com" + again, err := client.Servers() + if err != nil { + t.Fatalf("unexpected error on re-read: %v", err) + } + if !slices.Equal(again, []string{"https://api.example.com", "/relative"}) { + t.Fatalf("expected server snapshot to be cloned, got %v", again) + } +} + +func TestOpenAPIClient_Good_IteratorsExposeSnapshots(t *testing.T) { + client := api.NewOpenAPIClient(api.WithSpecReader(strings.NewReader(`openapi: 3.1.0 +info: + title: Test API + version: 1.0.0 +servers: + - url: https://api.example.com +paths: + /users/{id}: + post: + operationId: update_user + parameters: + - name: id + in: path + required: true + schema: + type: string + requestBody: + required: true + content: + application/json: + schema: + type: object +`))) + + operations, err := client.OperationsIter() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var operationIDs []string + for op := range operations { + operationIDs = append(operationIDs, op.OperationID) + } + if !slices.Equal(operationIDs, []string{"update_user"}) { + t.Fatalf("expected iterator to preserve operation snapshots, got %v", operationIDs) + } + + servers, err := client.ServersIter() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var serverURLs []string + for server := range servers { + serverURLs = append(serverURLs, server) + } + if !slices.Equal(serverURLs, []string{"https://api.example.com"}) { + t.Fatalf("expected iterator to preserve server snapshots, got %v", serverURLs) + } +} + +func TestOpenAPIClient_Good_CallHeadOperationWithRequestBody(t *testing.T) { + errCh := make(chan error, 1) + mux := http.NewServeMux() + mux.HandleFunc("/head", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodHead { + errCh <- fmt.Errorf("expected HEAD, got %s", r.Method) + w.WriteHeader(http.StatusInternalServerError) + return + } + if got := r.URL.RawQuery; got != "" { + errCh <- fmt.Errorf("expected no query string, got %q", got) + w.WriteHeader(http.StatusInternalServerError) + return + } + body, err := io.ReadAll(r.Body) + if err != nil { + errCh <- fmt.Errorf("read body: %v", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + if string(body) != `{"name":"Ada"}` { + errCh <- fmt.Errorf("expected JSON body {\"name\":\"Ada\"}, got %q", string(body)) + w.WriteHeader(http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + specPath := writeTempSpec(t, `openapi: 3.1.0 +info: + title: Test API + version: 1.0.0 +paths: + /head: + head: + operationId: head_check + requestBody: + required: true + content: + application/json: + schema: + type: object +`) + + client := api.NewOpenAPIClient( + api.WithSpec(specPath), + api.WithBaseURL(srv.URL), + ) + + result, err := client.Call("head_check", map[string]any{ + "name": "Ada", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + select { + case err := <-errCh: + t.Fatal(err) + default: + } + if result != nil { + t.Fatalf("expected nil result for empty HEAD response body, got %T", result) + } +} + +func TestOpenAPIClient_Good_CallOperationWithRepeatedQueryValues(t *testing.T) { + errCh := make(chan error, 1) + mux := http.NewServeMux() + mux.HandleFunc("/search", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + errCh <- fmt.Errorf("expected GET, got %s", r.Method) + w.WriteHeader(http.StatusInternalServerError) + return + } + if got := r.URL.Query()["tag"]; len(got) != 2 || got[0] != "alpha" || got[1] != "beta" { + errCh <- fmt.Errorf("expected repeated tag values [alpha beta], got %v", got) + w.WriteHeader(http.StatusInternalServerError) + return + } + if got := r.URL.Query().Get("page"); got != "2" { + errCh <- fmt.Errorf("expected page=2, got %q", got) + w.WriteHeader(http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"success":true,"data":{"ok":true}}`)) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + specPath := writeTempSpec(t, `openapi: 3.1.0 +info: + title: Test API + version: 1.0.0 +paths: + /search: + get: + operationId: search_items +`) + + client := api.NewOpenAPIClient( + api.WithSpec(specPath), + api.WithBaseURL(srv.URL), + ) + + result, err := client.Call("search_items", map[string]any{ + "tag": []string{"alpha", "beta"}, + "page": 2, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + select { + case err := <-errCh: + t.Fatal(err) + default: + } + + decoded, ok := result.(map[string]any) + if !ok { + t.Fatalf("expected map result, got %T", result) + } + if okValue, ok := decoded["ok"].(bool); !ok || !okValue { + t.Fatalf("expected ok=true, got %#v", decoded["ok"]) + } +} + +func TestOpenAPIClient_Good_UsesTopLevelQueryParametersOnPost(t *testing.T) { + errCh := make(chan error, 1) + mux := http.NewServeMux() + mux.HandleFunc("/submit", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + errCh <- fmt.Errorf("expected POST, got %s", r.Method) + w.WriteHeader(http.StatusInternalServerError) + return + } + if got := r.URL.Query().Get("verbose"); got != "true" { + errCh <- fmt.Errorf("expected query verbose=true, got %q", got) + w.WriteHeader(http.StatusInternalServerError) + return + } + body, err := io.ReadAll(r.Body) + if err != nil { + errCh <- fmt.Errorf("read body: %v", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + if string(body) != `{"name":"Ada"}` { + errCh <- fmt.Errorf("expected JSON body {\"name\":\"Ada\"}, got %q", string(body)) + w.WriteHeader(http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"success":true,"data":{"ok":true}}`)) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + specPath := writeTempSpec(t, `openapi: 3.1.0 +info: + title: Test API + version: 1.0.0 +paths: + /submit: + post: + operationId: submit_item + requestBody: + required: true + content: + application/json: + schema: + type: object + parameters: + - name: verbose + in: query +`) + + client := api.NewOpenAPIClient( + api.WithSpec(specPath), + api.WithBaseURL(srv.URL), + ) + + result, err := client.Call("submit_item", map[string]any{ + "verbose": true, + "name": "Ada", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + select { + case err := <-errCh: + t.Fatal(err) + default: + } + + decoded, ok := result.(map[string]any) + if !ok { + t.Fatalf("expected map result, got %T", result) + } + if okValue, ok := decoded["ok"].(bool); !ok || !okValue { + t.Fatalf("expected ok=true, got %#v", decoded["ok"]) + } +} + +func TestOpenAPIClient_Bad_MissingRequiredQueryParameter(t *testing.T) { + called := make(chan struct{}, 1) + mux := http.NewServeMux() + mux.HandleFunc("/submit", func(w http.ResponseWriter, r *http.Request) { + called <- struct{}{} + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"success":true,"data":{"ok":true}}`)) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + specPath := writeTempSpec(t, `openapi: 3.1.0 +info: + title: Test API + version: 1.0.0 +paths: + /submit: + post: + operationId: submit_item + parameters: + - name: verbose + in: query + required: true +`) + + client := api.NewOpenAPIClient( + api.WithSpec(specPath), + api.WithBaseURL(srv.URL), + ) + + if _, err := client.Call("submit_item", map[string]any{ + "name": "Ada", + }); err == nil { + t.Fatal("expected required query parameter validation error, got nil") + } + + select { + case <-called: + t.Fatal("expected validation to fail before the HTTP call") + default: + } +} + +func TestOpenAPIClient_Bad_ValidatesQueryParameterAgainstSchema(t *testing.T) { + called := make(chan struct{}, 1) + mux := http.NewServeMux() + mux.HandleFunc("/search", func(w http.ResponseWriter, r *http.Request) { + called <- struct{}{} + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"success":true,"data":{"ok":true}}`)) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + specPath := writeTempSpec(t, `openapi: 3.1.0 +info: + title: Test API + version: 1.0.0 +paths: + /search: + get: + operationId: search_items + parameters: + - name: page + in: query + schema: + type: integer +`) + + client := api.NewOpenAPIClient( + api.WithSpec(specPath), + api.WithBaseURL(srv.URL), + ) + + if _, err := client.Call("search_items", map[string]any{ + "page": "two", + }); err == nil { + t.Fatal("expected query parameter validation error, got nil") + } + + select { + case <-called: + t.Fatal("expected validation to fail before the HTTP call") + default: + } +} + +func TestOpenAPIClient_Bad_ValidatesPathParameterAgainstSchema(t *testing.T) { + called := make(chan struct{}, 1) + mux := http.NewServeMux() + mux.HandleFunc("/users/123", func(w http.ResponseWriter, r *http.Request) { + called <- struct{}{} + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"success":true,"data":{"ok":true}}`)) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + specPath := writeTempSpec(t, `openapi: 3.1.0 +info: + title: Test API + version: 1.0.0 +paths: + /users/{id}: + get: + operationId: get_user + parameters: + - name: id + in: path + required: true + schema: + type: integer +`) + + client := api.NewOpenAPIClient( + api.WithSpec(specPath), + api.WithBaseURL(srv.URL), + ) + + if _, err := client.Call("get_user", map[string]any{ + "path": map[string]any{ + "id": "abc", + }, + }); err == nil { + t.Fatal("expected path parameter validation error, got nil") + } + + select { + case <-called: + t.Fatal("expected validation to fail before the HTTP call") + default: + } +} + +func TestOpenAPIClient_Good_UsesHeaderAndCookieParameters(t *testing.T) { + errCh := make(chan error, 1) + mux := http.NewServeMux() + mux.HandleFunc("/inspect", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + errCh <- fmt.Errorf("expected GET, got %s", r.Method) + w.WriteHeader(http.StatusInternalServerError) + return + } + if got := r.Header.Get("X-Trace-ID"); got != "trace-123" { + errCh <- fmt.Errorf("expected X-Trace-ID=trace-123, got %q", got) + w.WriteHeader(http.StatusInternalServerError) + return + } + if got := r.Header.Get("X-Custom-Header"); got != "custom-value" { + errCh <- fmt.Errorf("expected X-Custom-Header=custom-value, got %q", got) + w.WriteHeader(http.StatusInternalServerError) + return + } + session, err := r.Cookie("session_id") + if err != nil { + errCh <- fmt.Errorf("expected session_id cookie: %v", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + if session.Value != "cookie-123" { + errCh <- fmt.Errorf("expected session_id=cookie-123, got %q", session.Value) + w.WriteHeader(http.StatusInternalServerError) + return + } + pref, err := r.Cookie("pref") + if err != nil { + errCh <- fmt.Errorf("expected pref cookie: %v", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + if pref.Value != "dark" { + errCh <- fmt.Errorf("expected pref=dark, got %q", pref.Value) + w.WriteHeader(http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"success":true,"data":{"ok":true}}`)) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + specPath := writeTempSpec(t, `openapi: 3.1.0 +info: + title: Test API + version: 1.0.0 +paths: + /inspect: + get: + operationId: inspect_request + parameters: + - name: X-Trace-ID + in: header + - name: session_id + in: cookie +`) + + client := api.NewOpenAPIClient( + api.WithSpec(specPath), + api.WithBaseURL(srv.URL), + ) + + result, err := client.Call("inspect_request", map[string]any{ + "X-Trace-ID": "trace-123", + "session_id": "cookie-123", + "header": map[string]any{ + "X-Custom-Header": "custom-value", + }, + "cookie": map[string]any{ + "pref": "dark", + }, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + select { + case err := <-errCh: + t.Fatal(err) + default: + } + + decoded, ok := result.(map[string]any) + if !ok { + t.Fatalf("expected map result, got %T", result) + } + if okValue, ok := decoded["ok"].(bool); !ok || !okValue { + t.Fatalf("expected ok=true, got %#v", decoded["ok"]) + } +} + +func TestOpenAPIClient_Good_UsesFirstAbsoluteServer(t *testing.T) { + errCh := make(chan error, 1) + mux := http.NewServeMux() + mux.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + errCh <- fmt.Errorf("expected GET, got %s", r.Method) + w.WriteHeader(http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"success":true,"data":{"message":"hello"}}`)) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + specPath := writeTempSpec(t, `openapi: 3.1.0 +info: + title: Test API + version: 1.0.0 +servers: + - url: " `+srv.URL+` " + - url: / + - url: " `+srv.URL+` " +paths: + /hello: + get: + operationId: get_hello +`) + + client := api.NewOpenAPIClient( + api.WithSpec(specPath), + ) + + result, err := client.Call("get_hello", nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + select { + case err := <-errCh: + t.Fatal(err) + default: + } + + hello, ok := result.(map[string]any) + if !ok { + t.Fatalf("expected map result, got %T", result) + } + if hello["message"] != "hello" { + t.Fatalf("expected message=hello, got %#v", hello["message"]) + } +} + +func TestOpenAPIClient_Bad_ValidatesRequestBodyAgainstSchema(t *testing.T) { + called := make(chan struct{}, 1) + mux := http.NewServeMux() + mux.HandleFunc("/users", func(w http.ResponseWriter, r *http.Request) { + called <- struct{}{} + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"success":true,"data":{"id":"123"}}`)) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + specPath := writeTempSpec(t, `openapi: 3.1.0 +info: + title: Test API + version: 1.0.0 +paths: + /users: + post: + operationId: create_user + requestBody: + required: true + content: + application/json: + schema: + type: object + required: [name] + properties: + name: + type: string + responses: + "200": + description: OK + content: + application/json: + schema: + type: object + properties: + success: + type: boolean + data: + type: object + properties: + id: + type: string +`) + + client := api.NewOpenAPIClient( + api.WithSpec(specPath), + api.WithBaseURL(srv.URL), + ) + + if _, err := client.Call("create_user", map[string]any{ + "body": map[string]any{}, + }); err == nil { + t.Fatal("expected request body validation error, got nil") + } + + select { + case <-called: + t.Fatal("expected request validation to fail before the HTTP call") + default: + } +} + +func TestOpenAPIClient_Bad_ValidatesResponseAgainstSchema(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/users", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"success":true,"data":{"id":123}}`)) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + specPath := writeTempSpec(t, `openapi: 3.1.0 +info: + title: Test API + version: 1.0.0 +paths: + /users: + get: + operationId: list_users + responses: + "200": + description: OK + content: + application/json: + schema: + type: object + required: [success, data] + properties: + success: + type: boolean + data: + type: object + required: [id] + properties: + id: + type: string +`) + + client := api.NewOpenAPIClient( + api.WithSpec(specPath), + api.WithBaseURL(srv.URL), + ) + + if _, err := client.Call("list_users", nil); err == nil { + t.Fatal("expected response validation error, got nil") + } +} + +func TestOpenAPIClient_Bad_MissingOperation(t *testing.T) { + specPath := writeTempSpec(t, `openapi: 3.1.0 +info: + title: Test API + version: 1.0.0 +paths: {} +`) + + client := api.NewOpenAPIClient( + api.WithSpec(specPath), + api.WithBaseURL("http://example.invalid"), + ) + + if _, err := client.Call("missing", nil); err == nil { + t.Fatal("expected error for missing operation, got nil") + } +} + +func writeTempSpec(t *testing.T, contents string) string { + t.Helper() + + dir := t.TempDir() + path := filepath.Join(dir, "openapi.yaml") + if err := os.WriteFile(path, []byte(contents), 0o600); err != nil { + t.Fatalf("write spec: %v", err) + } + return path +} diff --git a/cmd/api/cmd.go b/cmd/api/cmd.go index e0fb419..0dd63cc 100644 --- a/cmd/api/cmd.go +++ b/cmd/api/cmd.go @@ -8,7 +8,12 @@ func init() { cli.RegisterCommands(AddAPICommands) } -// AddAPICommands registers the 'api' command group. +// AddAPICommands registers the `api` command group. +// +// Example: +// +// root := &cli.Command{Use: "root"} +// api.AddAPICommands(root) func AddAPICommands(root *cli.Command) { apiCmd := cli.NewGroup("api", "API specification and SDK generation", "") root.AddCommand(apiCmd) diff --git a/cmd/api/cmd_args.go b/cmd/api/cmd_args.go new file mode 100644 index 0000000..042529a --- /dev/null +++ b/cmd/api/cmd_args.go @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api + +import "strings" + +// splitUniqueCSV trims and deduplicates a comma-separated list while +// preserving the first occurrence of each value. +func splitUniqueCSV(raw string) []string { + if raw == "" { + return nil + } + + parts := strings.Split(raw, ",") + values := make([]string, 0, len(parts)) + seen := make(map[string]struct{}, len(parts)) + + for _, part := range parts { + value := strings.TrimSpace(part) + if value == "" { + continue + } + if _, ok := seen[value]; ok { + continue + } + seen[value] = struct{}{} + values = append(values, value) + } + + return values +} + +// normalisePublicPaths trims whitespace, ensures a leading slash, and removes +// duplicate entries while preserving the first occurrence of each path. +func normalisePublicPaths(paths []string) []string { + if len(paths) == 0 { + return nil + } + + out := make([]string, 0, len(paths)) + seen := make(map[string]struct{}, len(paths)) + + for _, path := range paths { + path = strings.TrimSpace(path) + if path == "" { + continue + } + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + path = strings.TrimRight(path, "/") + if path == "" { + path = "/" + } + if _, ok := seen[path]; ok { + continue + } + seen[path] = struct{}{} + out = append(out, path) + } + + if len(out) == 0 { + return nil + } + + return out +} diff --git a/cmd/api/cmd_sdk.go b/cmd/api/cmd_sdk.go index be5c9ee..dfd5a38 100644 --- a/cmd/api/cmd_sdk.go +++ b/cmd/api/cmd_sdk.go @@ -3,8 +3,8 @@ package api import ( - "context" "fmt" + "iter" "os" "strings" @@ -16,46 +16,32 @@ import ( goapi "dappco.re/go/core/api" ) +const ( + defaultSDKTitle = "Lethean Core API" + defaultSDKDescription = "Lethean Core API" + defaultSDKVersion = "1.0.0" +) + func addSDKCommand(parent *cli.Command) { var ( lang string output string specFile string packageName string + cfg specBuilderConfig ) - cmd := cli.NewCommand("sdk", "Generate client SDKs from OpenAPI spec", "", func(cmd *cli.Command, args []string) error { - if lang == "" { - return coreerr.E("sdk.Generate", "--lang is required. Supported: "+strings.Join(goapi.SupportedLanguages(), ", "), nil) - } - - // If no spec file provided, generate one to a temp file. - if specFile == "" { - builder := &goapi.SpecBuilder{ - Title: "Lethean Core API", - Description: "Lethean Core API", - Version: "1.0.0", - } - - bridge := goapi.NewToolBridge("/tools") - groups := []goapi.RouteGroup{bridge} - - tmpFile, err := os.CreateTemp("", "openapi-*.json") - if err != nil { - return coreerr.E("sdk.Generate", "create temp spec file", err) - } - defer coreio.Local.Delete(tmpFile.Name()) + cfg.title = defaultSDKTitle + cfg.description = defaultSDKDescription + cfg.version = defaultSDKVersion - if err := goapi.ExportSpec(tmpFile, "json", builder, groups); err != nil { - tmpFile.Close() - return coreerr.E("sdk.Generate", "generate spec", err) - } - tmpFile.Close() - specFile = tmpFile.Name() + cmd := cli.NewCommand("sdk", "Generate client SDKs from OpenAPI spec", "", func(cmd *cli.Command, args []string) error { + languages := splitUniqueCSV(lang) + if len(languages) == 0 { + return coreerr.E("sdk.Generate", "--lang is required and must include at least one non-empty language. Supported: "+strings.Join(goapi.SupportedLanguages(), ", "), nil) } gen := &goapi.SDKGenerator{ - SpecPath: specFile, OutputDir: output, PackageName: packageName, } @@ -67,14 +53,39 @@ func addSDKCommand(parent *cli.Command) { return coreerr.E("sdk.Generate", "openapi-generator-cli not installed", nil) } - // Generate for each language. - for l := range strings.SplitSeq(lang, ",") { - l = strings.TrimSpace(l) - if l == "" { - continue + // If no spec file was provided, generate one only after confirming the + // generator is available. + resolvedSpecFile := specFile + if resolvedSpecFile == "" { + builder, err := sdkSpecBuilder(cfg) + if err != nil { + return err } + groups := sdkSpecGroupsIter() + + tmpFile, err := os.CreateTemp("", "openapi-*.json") + if err != nil { + return coreerr.E("sdk.Generate", "create temp spec file", err) + } + tmpPath := tmpFile.Name() + if err := tmpFile.Close(); err != nil { + _ = coreio.Local.Delete(tmpPath) + return coreerr.E("sdk.Generate", "close temp spec file", err) + } + defer coreio.Local.Delete(tmpPath) + + if err := goapi.ExportSpecToFileIter(tmpPath, "json", builder, groups); err != nil { + return coreerr.E("sdk.Generate", "generate spec", err) + } + resolvedSpecFile = tmpPath + } + + gen.SpecPath = resolvedSpecFile + + // Generate for each language. + for _, l := range languages { fmt.Fprintf(os.Stderr, "Generating %s SDK...\n", l) - if err := gen.Generate(context.Background(), l); err != nil { + if err := gen.Generate(cli.Context(), l); err != nil { return coreerr.E("sdk.Generate", "generate "+l, err) } fmt.Fprintf(os.Stderr, " Done: %s/%s/\n", output, l) @@ -85,8 +96,50 @@ func addSDKCommand(parent *cli.Command) { cli.StringFlag(cmd, &lang, "lang", "l", "", "Target language(s), comma-separated (e.g. go,python,typescript-fetch)") cli.StringFlag(cmd, &output, "output", "o", "./sdk", "Output directory for generated SDKs") - cli.StringFlag(cmd, &specFile, "spec", "s", "", "Path to existing OpenAPI spec (generates from MCP tools if not provided)") + cli.StringFlag(cmd, &specFile, "spec", "s", "", "Path to an existing OpenAPI spec (generates a temporary spec from registered route groups and the built-in tool bridge if not provided)") cli.StringFlag(cmd, &packageName, "package", "p", "lethean", "Package name for generated SDK") + registerSpecBuilderFlags(cmd, &cfg) parent.AddCommand(cmd) } + +func sdkSpecBuilder(cfg specBuilderConfig) (*goapi.SpecBuilder, error) { + return newSpecBuilder(specBuilderConfig{ + title: cfg.title, + summary: cfg.summary, + description: cfg.description, + version: cfg.version, + swaggerPath: cfg.swaggerPath, + graphqlPath: cfg.graphqlPath, + graphqlPlayground: cfg.graphqlPlayground, + graphqlPlaygroundPath: cfg.graphqlPlaygroundPath, + ssePath: cfg.ssePath, + wsPath: cfg.wsPath, + pprofEnabled: cfg.pprofEnabled, + expvarEnabled: cfg.expvarEnabled, + cacheEnabled: cfg.cacheEnabled, + cacheTTL: cfg.cacheTTL, + cacheMaxEntries: cfg.cacheMaxEntries, + cacheMaxBytes: cfg.cacheMaxBytes, + i18nDefaultLocale: cfg.i18nDefaultLocale, + i18nSupportedLocales: cfg.i18nSupportedLocales, + authentikIssuer: cfg.authentikIssuer, + authentikClientID: cfg.authentikClientID, + authentikTrustedProxy: cfg.authentikTrustedProxy, + authentikPublicPaths: cfg.authentikPublicPaths, + termsURL: cfg.termsURL, + contactName: cfg.contactName, + contactURL: cfg.contactURL, + contactEmail: cfg.contactEmail, + licenseName: cfg.licenseName, + licenseURL: cfg.licenseURL, + externalDocsDescription: cfg.externalDocsDescription, + externalDocsURL: cfg.externalDocsURL, + servers: cfg.servers, + securitySchemes: cfg.securitySchemes, + }) +} + +func sdkSpecGroupsIter() iter.Seq[goapi.RouteGroup] { + return specGroupsIter(goapi.NewToolBridge("/tools")) +} diff --git a/cmd/api/cmd_spec.go b/cmd/api/cmd_spec.go index 7ad145e..57d6af6 100644 --- a/cmd/api/cmd_spec.go +++ b/cmd/api/cmd_spec.go @@ -3,8 +3,10 @@ package api import ( + "encoding/json" "fmt" "os" + "strings" "forge.lthn.ai/core/cli/pkg/cli" @@ -13,42 +15,91 @@ import ( func addSpecCommand(parent *cli.Command) { var ( - output string - format string - title string - version string + output string + format string + cfg specBuilderConfig ) + cfg.title = "Lethean Core API" + cfg.description = "Lethean Core API" + cfg.version = "1.0.0" + cmd := cli.NewCommand("spec", "Generate OpenAPI specification", "", func(cmd *cli.Command, args []string) error { - // Build spec from registered route groups. - // Additional groups can be added here as the platform grows. - builder := &goapi.SpecBuilder{ - Title: title, - Description: "Lethean Core API", - Version: version, + // Build spec from all route groups registered for CLI generation. + builder, err := newSpecBuilder(cfg) + if err != nil { + return err } - // Start with the default tool bridge — future versions will - // auto-populate from the MCP tool registry once the bridge - // integration lands in the local go-ai module. bridge := goapi.NewToolBridge("/tools") - groups := []goapi.RouteGroup{bridge} + groups := specGroupsIter(bridge) if output != "" { - if err := goapi.ExportSpecToFile(output, format, builder, groups); err != nil { + if err := goapi.ExportSpecToFileIter(output, format, builder, groups); err != nil { return err } fmt.Fprintf(os.Stderr, "Spec written to %s\n", output) return nil } - return goapi.ExportSpec(os.Stdout, format, builder, groups) + return goapi.ExportSpecIter(os.Stdout, format, builder, groups) }) cli.StringFlag(cmd, &output, "output", "o", "", "Write spec to file instead of stdout") cli.StringFlag(cmd, &format, "format", "f", "json", "Output format: json or yaml") - cli.StringFlag(cmd, &title, "title", "t", "Lethean Core API", "API title in spec") - cli.StringFlag(cmd, &version, "version", "V", "1.0.0", "API version in spec") + registerSpecBuilderFlags(cmd, &cfg) parent.AddCommand(cmd) } + +func parseServers(raw string) []string { + return splitUniqueCSV(raw) +} + +func parseSecuritySchemes(raw string) (map[string]any, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return nil, nil + } + + var schemes map[string]any + if err := json.Unmarshal([]byte(raw), &schemes); err != nil { + return nil, cli.Err("invalid security schemes JSON: %w", err) + } + return schemes, nil +} + +func registerSpecBuilderFlags(cmd *cli.Command, cfg *specBuilderConfig) { + cli.StringFlag(cmd, &cfg.title, "title", "t", cfg.title, "API title in spec") + cli.StringFlag(cmd, &cfg.summary, "summary", "", cfg.summary, "OpenAPI info summary in spec") + cli.StringFlag(cmd, &cfg.description, "description", "d", cfg.description, "API description in spec") + cli.StringFlag(cmd, &cfg.version, "version", "V", cfg.version, "API version in spec") + cli.StringFlag(cmd, &cfg.swaggerPath, "swagger-path", "", "", "Swagger UI path in generated spec") + cli.StringFlag(cmd, &cfg.graphqlPath, "graphql-path", "", "", "GraphQL endpoint path in generated spec") + cli.BoolFlag(cmd, &cfg.graphqlPlayground, "graphql-playground", "", false, "Include the GraphQL playground endpoint in generated spec") + cli.StringFlag(cmd, &cfg.graphqlPlaygroundPath, "graphql-playground-path", "", "", "GraphQL playground path in generated spec") + cli.StringFlag(cmd, &cfg.ssePath, "sse-path", "", "", "SSE endpoint path in generated spec") + cli.StringFlag(cmd, &cfg.wsPath, "ws-path", "", "", "WebSocket endpoint path in generated spec") + cli.BoolFlag(cmd, &cfg.pprofEnabled, "pprof", "", false, "Include pprof endpoints in generated spec") + cli.BoolFlag(cmd, &cfg.expvarEnabled, "expvar", "", false, "Include expvar endpoint in generated spec") + cli.BoolFlag(cmd, &cfg.cacheEnabled, "cache", "", false, "Include cache metadata in generated spec") + cli.StringFlag(cmd, &cfg.cacheTTL, "cache-ttl", "", "", "Cache TTL in generated spec") + cli.IntFlag(cmd, &cfg.cacheMaxEntries, "cache-max-entries", "", 0, "Cache max entries in generated spec") + cli.IntFlag(cmd, &cfg.cacheMaxBytes, "cache-max-bytes", "", 0, "Cache max bytes in generated spec") + cli.StringFlag(cmd, &cfg.i18nDefaultLocale, "i18n-default-locale", "", "", "Default locale in generated spec") + cli.StringFlag(cmd, &cfg.i18nSupportedLocales, "i18n-supported-locales", "", "", "Comma-separated supported locales in generated spec") + cli.StringFlag(cmd, &cfg.authentikIssuer, "authentik-issuer", "", "", "Authentik issuer URL in generated spec") + cli.StringFlag(cmd, &cfg.authentikClientID, "authentik-client-id", "", "", "Authentik client ID in generated spec") + cli.BoolFlag(cmd, &cfg.authentikTrustedProxy, "authentik-trusted-proxy", "", false, "Mark Authentik proxy headers as trusted in generated spec") + cli.StringFlag(cmd, &cfg.authentikPublicPaths, "authentik-public-paths", "", "", "Comma-separated public paths in generated spec") + cli.StringFlag(cmd, &cfg.termsURL, "terms-of-service", "", "", "OpenAPI terms of service URL in spec") + cli.StringFlag(cmd, &cfg.contactName, "contact-name", "", "", "OpenAPI contact name in spec") + cli.StringFlag(cmd, &cfg.contactURL, "contact-url", "", "", "OpenAPI contact URL in spec") + cli.StringFlag(cmd, &cfg.contactEmail, "contact-email", "", "", "OpenAPI contact email in spec") + cli.StringFlag(cmd, &cfg.licenseName, "license-name", "", "", "OpenAPI licence name in spec") + cli.StringFlag(cmd, &cfg.licenseURL, "license-url", "", "", "OpenAPI licence URL in spec") + cli.StringFlag(cmd, &cfg.externalDocsDescription, "external-docs-description", "", "", "OpenAPI external documentation description in spec") + cli.StringFlag(cmd, &cfg.externalDocsURL, "external-docs-url", "", "", "OpenAPI external documentation URL in spec") + cli.StringFlag(cmd, &cfg.servers, "server", "S", "", "Comma-separated OpenAPI server URL(s)") + cli.StringFlag(cmd, &cfg.securitySchemes, "security-schemes", "", "", "JSON object of custom OpenAPI security schemes") +} diff --git a/cmd/api/cmd_test.go b/cmd/api/cmd_test.go index b24c723..09361b7 100644 --- a/cmd/api/cmd_test.go +++ b/cmd/api/cmd_test.go @@ -4,11 +4,45 @@ package api import ( "bytes" + "encoding/json" + "iter" + "os" "testing" + "github.com/gin-gonic/gin" + "forge.lthn.ai/core/cli/pkg/cli" + + api "dappco.re/go/core/api" ) +type specCmdStubGroup struct{} + +func (specCmdStubGroup) Name() string { return "registered" } +func (specCmdStubGroup) BasePath() string { return "/registered" } +func (specCmdStubGroup) RegisterRoutes(rg *gin.RouterGroup) {} +func (specCmdStubGroup) Describe() []api.RouteDescription { + return []api.RouteDescription{ + { + Method: "GET", + Path: "/ping", + Summary: "Ping registered group", + Tags: []string{"registered"}, + Response: map[string]any{ + "type": "string", + }, + }, + } +} + +func collectRouteGroups(groups iter.Seq[api.RouteGroup]) []api.RouteGroup { + out := make([]api.RouteGroup, 0) + for group := range groups { + out = append(out, group) + } + return out +} + func TestAPISpecCmd_Good_CommandStructure(t *testing.T) { root := &cli.Command{Use: "root"} AddAPICommands(root) @@ -51,9 +85,991 @@ func TestAPISpecCmd_Good_JSON(t *testing.T) { if specCmd.Flag("title") == nil { t.Fatal("expected --title flag on spec command") } + if specCmd.Flag("summary") == nil { + t.Fatal("expected --summary flag on spec command") + } + if specCmd.Flag("description") == nil { + t.Fatal("expected --description flag on spec command") + } if specCmd.Flag("version") == nil { t.Fatal("expected --version flag on spec command") } + if specCmd.Flag("swagger-path") == nil { + t.Fatal("expected --swagger-path flag on spec command") + } + if specCmd.Flag("graphql-path") == nil { + t.Fatal("expected --graphql-path flag on spec command") + } + if specCmd.Flag("graphql-playground") == nil { + t.Fatal("expected --graphql-playground flag on spec command") + } + if specCmd.Flag("graphql-playground-path") == nil { + t.Fatal("expected --graphql-playground-path flag on spec command") + } + if specCmd.Flag("sse-path") == nil { + t.Fatal("expected --sse-path flag on spec command") + } + if specCmd.Flag("ws-path") == nil { + t.Fatal("expected --ws-path flag on spec command") + } + if specCmd.Flag("pprof") == nil { + t.Fatal("expected --pprof flag on spec command") + } + if specCmd.Flag("expvar") == nil { + t.Fatal("expected --expvar flag on spec command") + } + if specCmd.Flag("cache") == nil { + t.Fatal("expected --cache flag on spec command") + } + if specCmd.Flag("cache-ttl") == nil { + t.Fatal("expected --cache-ttl flag on spec command") + } + if specCmd.Flag("cache-max-entries") == nil { + t.Fatal("expected --cache-max-entries flag on spec command") + } + if specCmd.Flag("cache-max-bytes") == nil { + t.Fatal("expected --cache-max-bytes flag on spec command") + } + if specCmd.Flag("i18n-default-locale") == nil { + t.Fatal("expected --i18n-default-locale flag on spec command") + } + if specCmd.Flag("i18n-supported-locales") == nil { + t.Fatal("expected --i18n-supported-locales flag on spec command") + } + if specCmd.Flag("terms-of-service") == nil { + t.Fatal("expected --terms-of-service flag on spec command") + } + if specCmd.Flag("contact-name") == nil { + t.Fatal("expected --contact-name flag on spec command") + } + if specCmd.Flag("contact-url") == nil { + t.Fatal("expected --contact-url flag on spec command") + } + if specCmd.Flag("contact-email") == nil { + t.Fatal("expected --contact-email flag on spec command") + } + if specCmd.Flag("license-name") == nil { + t.Fatal("expected --license-name flag on spec command") + } + if specCmd.Flag("license-url") == nil { + t.Fatal("expected --license-url flag on spec command") + } + if specCmd.Flag("external-docs-description") == nil { + t.Fatal("expected --external-docs-description flag on spec command") + } + if specCmd.Flag("external-docs-url") == nil { + t.Fatal("expected --external-docs-url flag on spec command") + } + if specCmd.Flag("server") == nil { + t.Fatal("expected --server flag on spec command") + } + if specCmd.Flag("security-schemes") == nil { + t.Fatal("expected --security-schemes flag on spec command") + } +} + +func TestAPISpecCmd_Good_CustomDescription(t *testing.T) { + root := &cli.Command{Use: "root"} + AddAPICommands(root) + + outputFile := t.TempDir() + "/spec.json" + root.SetArgs([]string{"api", "spec", "--description", "Custom API description", "--swagger-path", "/docs", "--output", outputFile}) + root.SetErr(new(bytes.Buffer)) + + if err := root.Execute(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + data, err := os.ReadFile(outputFile) + if err != nil { + t.Fatalf("expected spec file to be written: %v", err) + } + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("expected valid JSON spec, got error: %v", err) + } + + if got := spec["x-swagger-ui-path"]; got != "/docs" { + t.Fatalf("expected x-swagger-ui-path=/docs, got %v", got) + } + + info, ok := spec["info"].(map[string]any) + if !ok { + t.Fatal("expected info object in generated spec") + } + if info["description"] != "Custom API description" { + t.Fatalf("expected custom description, got %v", info["description"]) + } +} + +func TestAPISpecCmd_Good_SummaryPopulatesSpecInfo(t *testing.T) { + root := &cli.Command{Use: "root"} + AddAPICommands(root) + + outputFile := t.TempDir() + "/spec.json" + root.SetArgs([]string{ + "api", "spec", + "--summary", "Short API overview", + "--output", outputFile, + }) + root.SetErr(new(bytes.Buffer)) + + if err := root.Execute(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + data, err := os.ReadFile(outputFile) + if err != nil { + t.Fatalf("expected spec file to be written: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("expected valid JSON spec, got error: %v", err) + } + + info, ok := spec["info"].(map[string]any) + if !ok { + t.Fatal("expected info object in generated spec") + } + if info["summary"] != "Short API overview" { + t.Fatalf("expected summary to be preserved, got %v", info["summary"]) + } +} + +func TestNewSpecBuilder_Good_TrimsMetadata(t *testing.T) { + builder, err := newSpecBuilder(specBuilderConfig{ + title: " API Title ", + summary: " API Summary ", + description: " API Description ", + version: " 1.2.3 ", + termsURL: " https://example.com/terms ", + contactName: " API Support ", + contactURL: " https://example.com/support ", + contactEmail: " support@example.com ", + licenseName: " EUPL-1.2 ", + licenseURL: " https://eupl.eu/1.2/en/ ", + externalDocsDescription: " Developer guide ", + externalDocsURL: " https://example.com/docs ", + servers: " https://api.example.com , / ", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if builder.Title != "API Title" { + t.Fatalf("expected trimmed title, got %q", builder.Title) + } + if builder.Summary != "API Summary" { + t.Fatalf("expected trimmed summary, got %q", builder.Summary) + } + if builder.Description != "API Description" { + t.Fatalf("expected trimmed description, got %q", builder.Description) + } + if builder.Version != "1.2.3" { + t.Fatalf("expected trimmed version, got %q", builder.Version) + } + if builder.TermsOfService != "https://example.com/terms" { + t.Fatalf("expected trimmed terms URL, got %q", builder.TermsOfService) + } + if builder.ContactName != "API Support" { + t.Fatalf("expected trimmed contact name, got %q", builder.ContactName) + } + if builder.ContactURL != "https://example.com/support" { + t.Fatalf("expected trimmed contact URL, got %q", builder.ContactURL) + } + if builder.ContactEmail != "support@example.com" { + t.Fatalf("expected trimmed contact email, got %q", builder.ContactEmail) + } + if builder.LicenseName != "EUPL-1.2" { + t.Fatalf("expected trimmed licence name, got %q", builder.LicenseName) + } + if builder.LicenseURL != "https://eupl.eu/1.2/en/" { + t.Fatalf("expected trimmed licence URL, got %q", builder.LicenseURL) + } + if builder.ExternalDocsDescription != "Developer guide" { + t.Fatalf("expected trimmed external docs description, got %q", builder.ExternalDocsDescription) + } + if builder.ExternalDocsURL != "https://example.com/docs" { + t.Fatalf("expected trimmed external docs URL, got %q", builder.ExternalDocsURL) + } + if len(builder.Servers) != 2 || builder.Servers[0] != "https://api.example.com" || builder.Servers[1] != "/" { + t.Fatalf("expected trimmed servers, got %v", builder.Servers) + } +} + +func TestAPISpecCmd_Good_CacheAndI18nFlagsPopulateSpec(t *testing.T) { + root := &cli.Command{Use: "root"} + AddAPICommands(root) + + outputFile := t.TempDir() + "/spec.json" + root.SetArgs([]string{ + "api", "spec", + "--cache", + "--cache-ttl", "5m0s", + "--cache-max-entries", "42", + "--cache-max-bytes", "8192", + "--i18n-default-locale", "en-GB", + "--i18n-supported-locales", "en-GB,fr,en-GB", + "--output", outputFile, + }) + root.SetErr(new(bytes.Buffer)) + + if err := root.Execute(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + data, err := os.ReadFile(outputFile) + if err != nil { + t.Fatalf("expected spec file to be written: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("expected valid JSON spec, got error: %v", err) + } + + if got := spec["x-cache-enabled"]; got != true { + t.Fatalf("expected x-cache-enabled=true, got %v", got) + } + if got := spec["x-cache-ttl"]; got != "5m0s" { + t.Fatalf("expected x-cache-ttl=5m0s, got %v", got) + } + if got := spec["x-cache-max-entries"]; got != float64(42) { + t.Fatalf("expected x-cache-max-entries=42, got %v", got) + } + if got := spec["x-cache-max-bytes"]; got != float64(8192) { + t.Fatalf("expected x-cache-max-bytes=8192, got %v", got) + } + if got := spec["x-i18n-default-locale"]; got != "en-GB" { + t.Fatalf("expected x-i18n-default-locale=en-GB, got %v", got) + } + locales, ok := spec["x-i18n-supported-locales"].([]any) + if !ok { + t.Fatalf("expected x-i18n-supported-locales array, got %T", spec["x-i18n-supported-locales"]) + } + if len(locales) != 2 || locales[0] != "en-GB" || locales[1] != "fr" { + t.Fatalf("expected supported locales [en-GB fr], got %v", locales) + } +} + +func TestNewSpecBuilder_Good_IgnoresNonPositiveCacheTTL(t *testing.T) { + builder, err := newSpecBuilder(specBuilderConfig{ + cacheTTL: "0s", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if builder.CacheEnabled { + t.Fatal("expected non-positive cache TTL to keep cache disabled") + } + if builder.CacheTTL != "0s" { + t.Fatalf("expected cache TTL metadata to be preserved, got %q", builder.CacheTTL) + } +} + +func TestNewSpecBuilder_Good_IgnoresCacheLimitsWithoutPositiveTTL(t *testing.T) { + builder, err := newSpecBuilder(specBuilderConfig{ + cacheMaxEntries: 42, + cacheMaxBytes: 8192, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if builder.CacheEnabled { + t.Fatal("expected cache limits without a positive TTL to keep cache disabled") + } + if builder.CacheMaxEntries != 42 { + t.Fatalf("expected cache max entries metadata to be preserved, got %d", builder.CacheMaxEntries) + } + if builder.CacheMaxBytes != 8192 { + t.Fatalf("expected cache max bytes metadata to be preserved, got %d", builder.CacheMaxBytes) + } +} + +func TestAPISpecCmd_Good_OmitsNonPositiveCacheTTLExtension(t *testing.T) { + root := &cli.Command{Use: "root"} + AddAPICommands(root) + + outputFile := t.TempDir() + "/spec.json" + root.SetArgs([]string{ + "api", "spec", + "--cache-ttl", "0s", + "--output", outputFile, + }) + root.SetErr(new(bytes.Buffer)) + + if err := root.Execute(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + data, err := os.ReadFile(outputFile) + if err != nil { + t.Fatalf("expected spec file to be written: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("expected valid JSON spec, got error: %v", err) + } + + if _, ok := spec["x-cache-ttl"]; ok { + t.Fatal("expected non-positive cache TTL to be omitted from generated spec") + } + if got := spec["x-cache-enabled"]; got != nil && got != false { + t.Fatalf("expected cache to remain disabled, got %v", got) + } +} + +func TestAPISpecCmd_Good_GraphQLPlaygroundFlagPopulatesSpecPaths(t *testing.T) { + root := &cli.Command{Use: "root"} + AddAPICommands(root) + + outputFile := t.TempDir() + "/spec.json" + root.SetArgs([]string{ + "api", "spec", + "--graphql-path", "/graphql", + "--graphql-playground", + "--output", outputFile, + }) + root.SetErr(new(bytes.Buffer)) + + if err := root.Execute(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + data, err := os.ReadFile(outputFile) + if err != nil { + t.Fatalf("expected spec file to be written: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("expected valid JSON spec, got error: %v", err) + } + + paths, ok := spec["paths"].(map[string]any) + if !ok { + t.Fatal("expected paths object in generated spec") + } + if _, ok := paths["/graphql/playground"]; !ok { + t.Fatal("expected GraphQL playground path in generated spec") + } +} + +func TestAPISpecCmd_Good_GraphQLPlaygroundPathFlagOverridesGeneratedPath(t *testing.T) { + root := &cli.Command{Use: "root"} + AddAPICommands(root) + + outputFile := t.TempDir() + "/spec.json" + root.SetArgs([]string{ + "api", "spec", + "--graphql-path", "/graphql", + "--graphql-playground", + "--graphql-playground-path", "/graphql-ui", + "--output", outputFile, + }) + root.SetErr(new(bytes.Buffer)) + + if err := root.Execute(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + data, err := os.ReadFile(outputFile) + if err != nil { + t.Fatalf("expected spec file to be written: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("expected valid JSON spec, got error: %v", err) + } + + paths, ok := spec["paths"].(map[string]any) + if !ok { + t.Fatal("expected paths object in generated spec") + } + if _, ok := paths["/graphql-ui"]; !ok { + t.Fatal("expected custom GraphQL playground path in generated spec") + } + if _, ok := paths["/graphql/playground"]; ok { + t.Fatal("expected default GraphQL playground path to be overridden") + } + + if got := spec["x-graphql-playground-path"]; got != "/graphql-ui" { + t.Fatalf("expected x-graphql-playground-path=/graphql-ui, got %v", got) + } +} + +func TestAPISpecCmd_Good_EnabledExtensionsFollowProvidedPaths(t *testing.T) { + root := &cli.Command{Use: "root"} + AddAPICommands(root) + + outputFile := t.TempDir() + "/spec.json" + root.SetArgs([]string{ + "api", "spec", + "--swagger-path", "/docs", + "--graphql-path", "/graphql", + "--ws-path", "/socket", + "--sse-path", "/events", + "--output", outputFile, + }) + root.SetErr(new(bytes.Buffer)) + + if err := root.Execute(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + data, err := os.ReadFile(outputFile) + if err != nil { + t.Fatalf("expected spec file to be written: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("expected valid JSON spec, got error: %v", err) + } + + if got := spec["x-swagger-enabled"]; got != true { + t.Fatalf("expected x-swagger-enabled=true, got %v", got) + } + if got := spec["x-graphql-enabled"]; got != true { + t.Fatalf("expected x-graphql-enabled=true, got %v", got) + } + if got := spec["x-ws-enabled"]; got != true { + t.Fatalf("expected x-ws-enabled=true, got %v", got) + } + if got := spec["x-sse-enabled"]; got != true { + t.Fatalf("expected x-sse-enabled=true, got %v", got) + } +} + +func TestAPISpecCmd_Good_AuthentikPublicPathsAreNormalised(t *testing.T) { + root := &cli.Command{Use: "root"} + AddAPICommands(root) + + outputFile := t.TempDir() + "/spec.json" + root.SetArgs([]string{ + "api", "spec", + "--authentik-public-paths", " /public/ ,docs,/public", + "--output", outputFile, + }) + root.SetErr(new(bytes.Buffer)) + + if err := root.Execute(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + data, err := os.ReadFile(outputFile) + if err != nil { + t.Fatalf("expected spec file to be written: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("expected valid JSON spec, got error: %v", err) + } + + paths, ok := spec["x-authentik-public-paths"].([]any) + if !ok { + t.Fatalf("expected x-authentik-public-paths array, got %T", spec["x-authentik-public-paths"]) + } + if len(paths) != 4 || paths[0] != "/health" || paths[1] != "/swagger" || paths[2] != "/public" || paths[3] != "/docs" { + t.Fatalf("expected normalised public paths [/health /swagger /public /docs], got %v", paths) + } +} + +func TestAPISpecCmd_Good_ContactFlagsPopulateSpecInfo(t *testing.T) { + root := &cli.Command{Use: "root"} + AddAPICommands(root) + + outputFile := t.TempDir() + "/spec.json" + root.SetArgs([]string{ + "api", "spec", + "--contact-name", "API Support", + "--contact-url", "https://example.com/support", + "--contact-email", "support@example.com", + "--output", outputFile, + }) + root.SetErr(new(bytes.Buffer)) + + if err := root.Execute(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + data, err := os.ReadFile(outputFile) + if err != nil { + t.Fatalf("expected spec file to be written: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("expected valid JSON spec, got error: %v", err) + } + + info, ok := spec["info"].(map[string]any) + if !ok { + t.Fatal("expected info object in generated spec") + } + + contact, ok := info["contact"].(map[string]any) + if !ok { + t.Fatal("expected contact metadata in generated spec") + } + if contact["name"] != "API Support" { + t.Fatalf("expected contact name API Support, got %v", contact["name"]) + } + if contact["url"] != "https://example.com/support" { + t.Fatalf("expected contact url to be preserved, got %v", contact["url"]) + } + if contact["email"] != "support@example.com" { + t.Fatalf("expected contact email to be preserved, got %v", contact["email"]) + } +} + +func TestAPISpecCmd_Good_SecuritySchemesFlagPopulatesSpecComponents(t *testing.T) { + root := &cli.Command{Use: "root"} + AddAPICommands(root) + + outputFile := t.TempDir() + "/spec.json" + root.SetArgs([]string{ + "api", "spec", + "--security-schemes", `{"apiKeyAuth":{"type":"apiKey","in":"header","name":"X-API-Key"}}`, + "--output", outputFile, + }) + root.SetErr(new(bytes.Buffer)) + + if err := root.Execute(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + data, err := os.ReadFile(outputFile) + if err != nil { + t.Fatalf("expected spec file to be written: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("expected valid JSON spec, got error: %v", err) + } + + securitySchemes, ok := spec["components"].(map[string]any)["securitySchemes"].(map[string]any) + if !ok { + t.Fatal("expected securitySchemes object in generated spec") + } + apiKeyAuth, ok := securitySchemes["apiKeyAuth"].(map[string]any) + if !ok { + t.Fatal("expected apiKeyAuth security scheme in generated spec") + } + if apiKeyAuth["type"] != "apiKey" { + t.Fatalf("expected apiKeyAuth.type=apiKey, got %v", apiKeyAuth["type"]) + } + if apiKeyAuth["in"] != "header" { + t.Fatalf("expected apiKeyAuth.in=header, got %v", apiKeyAuth["in"]) + } + if apiKeyAuth["name"] != "X-API-Key" { + t.Fatalf("expected apiKeyAuth.name=X-API-Key, got %v", apiKeyAuth["name"]) + } +} + +func TestSpecGroupsIter_Good_DeduplicatesExtraBridge(t *testing.T) { + snapshot := api.RegisteredSpecGroups() + api.ResetSpecGroups() + t.Cleanup(func() { + api.ResetSpecGroups() + api.RegisterSpecGroups(snapshot...) + }) + + group := specCmdStubGroup{} + api.RegisterSpecGroups(group) + + var groups []api.RouteGroup + for g := range specGroupsIter(group) { + groups = append(groups, g) + } + + if len(groups) != 1 { + t.Fatalf("expected duplicate extra group to be skipped, got %d groups", len(groups)) + } + if groups[0].Name() != group.Name() || groups[0].BasePath() != group.BasePath() { + t.Fatalf("expected original group to be preserved, got %s at %s", groups[0].Name(), groups[0].BasePath()) + } +} + +func TestAPISpecCmd_Good_TermsOfServiceFlagPopulatesSpecInfo(t *testing.T) { + root := &cli.Command{Use: "root"} + AddAPICommands(root) + + outputFile := t.TempDir() + "/spec.json" + root.SetArgs([]string{ + "api", "spec", + "--terms-of-service", "https://example.com/terms", + "--output", outputFile, + }) + root.SetErr(new(bytes.Buffer)) + + if err := root.Execute(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + data, err := os.ReadFile(outputFile) + if err != nil { + t.Fatalf("expected spec file to be written: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("expected valid JSON spec, got error: %v", err) + } + + info, ok := spec["info"].(map[string]any) + if !ok { + t.Fatal("expected info object in generated spec") + } + if info["termsOfService"] != "https://example.com/terms" { + t.Fatalf("expected termsOfService to be preserved, got %v", info["termsOfService"]) + } +} + +func TestAPISpecCmd_Good_ExternalDocsFlagsPopulateSpec(t *testing.T) { + root := &cli.Command{Use: "root"} + AddAPICommands(root) + + outputFile := t.TempDir() + "/spec.json" + root.SetArgs([]string{ + "api", "spec", + "--external-docs-description", "Developer guide", + "--external-docs-url", "https://example.com/docs", + "--output", outputFile, + }) + root.SetErr(new(bytes.Buffer)) + + if err := root.Execute(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + data, err := os.ReadFile(outputFile) + if err != nil { + t.Fatalf("expected spec file to be written: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("expected valid JSON spec, got error: %v", err) + } + + externalDocs, ok := spec["externalDocs"].(map[string]any) + if !ok { + t.Fatal("expected externalDocs metadata in generated spec") + } + if externalDocs["description"] != "Developer guide" { + t.Fatalf("expected externalDocs description Developer guide, got %v", externalDocs["description"]) + } + if externalDocs["url"] != "https://example.com/docs" { + t.Fatalf("expected externalDocs url to be preserved, got %v", externalDocs["url"]) + } +} + +func TestAPISpecCmd_Good_ServerFlagAddsServers(t *testing.T) { + root := &cli.Command{Use: "root"} + AddAPICommands(root) + + outputFile := t.TempDir() + "/spec.json" + root.SetArgs([]string{"api", "spec", "--server", "https://api.example.com, /, https://api.example.com, ", "--output", outputFile}) + root.SetErr(new(bytes.Buffer)) + + if err := root.Execute(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + data, err := os.ReadFile(outputFile) + if err != nil { + t.Fatalf("expected spec file to be written: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("expected valid JSON spec, got error: %v", err) + } + + servers, ok := spec["servers"].([]any) + if !ok { + t.Fatalf("expected servers array in generated spec, got %T", spec["servers"]) + } + if len(servers) != 2 { + t.Fatalf("expected 2 servers, got %d", len(servers)) + } + if servers[0].(map[string]any)["url"] != "https://api.example.com" { + t.Fatalf("expected first server to be https://api.example.com, got %v", servers[0]) + } + if servers[1].(map[string]any)["url"] != "/" { + t.Fatalf("expected second server to be /, got %v", servers[1]) + } +} + +func TestAPISpecCmd_Good_RegisteredSpecGroups(t *testing.T) { + snapshot := api.RegisteredSpecGroups() + api.ResetSpecGroups() + t.Cleanup(func() { + api.ResetSpecGroups() + api.RegisterSpecGroups(snapshot...) + }) + + api.RegisterSpecGroups(specCmdStubGroup{}) + + root := &cli.Command{Use: "root"} + AddAPICommands(root) + + outputFile := t.TempDir() + "/spec.json" + root.SetArgs([]string{"api", "spec", "--output", outputFile}) + root.SetErr(new(bytes.Buffer)) + + if err := root.Execute(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + data, err := os.ReadFile(outputFile) + if err != nil { + t.Fatalf("expected spec file to be written: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("expected valid JSON spec, got error: %v", err) + } + + paths, ok := spec["paths"].(map[string]any) + if !ok { + t.Fatalf("expected paths object in generated spec, got %T", spec["paths"]) + } + + if _, ok := paths["/registered/ping"]; !ok { + t.Fatal("expected registered route group path in generated spec") + } +} + +func TestAPISpecCmd_Good_LicenseFlagsPopulateSpecInfo(t *testing.T) { + root := &cli.Command{Use: "root"} + AddAPICommands(root) + + outputFile := t.TempDir() + "/spec.json" + root.SetArgs([]string{ + "api", "spec", + "--license-name", "EUPL-1.2", + "--license-url", "https://eupl.eu/1.2/en/", + "--output", outputFile, + }) + root.SetErr(new(bytes.Buffer)) + + if err := root.Execute(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + data, err := os.ReadFile(outputFile) + if err != nil { + t.Fatalf("expected spec file to be written: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("expected valid JSON spec, got error: %v", err) + } + + info, ok := spec["info"].(map[string]any) + if !ok { + t.Fatal("expected info object in generated spec") + } + + license, ok := info["license"].(map[string]any) + if !ok { + t.Fatal("expected license metadata in generated spec") + } + if license["name"] != "EUPL-1.2" { + t.Fatalf("expected licence name EUPL-1.2, got %v", license["name"]) + } + if license["url"] != "https://eupl.eu/1.2/en/" { + t.Fatalf("expected licence url to be preserved, got %v", license["url"]) + } +} + +func TestAPISpecCmd_Good_GraphQLPathPopulatesSpec(t *testing.T) { + root := &cli.Command{Use: "root"} + AddAPICommands(root) + + outputFile := t.TempDir() + "/spec.json" + root.SetArgs([]string{ + "api", "spec", + "--graphql-path", "/gql", + "--output", outputFile, + }) + root.SetErr(new(bytes.Buffer)) + + if err := root.Execute(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + data, err := os.ReadFile(outputFile) + if err != nil { + t.Fatalf("expected spec file to be written: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("expected valid JSON spec, got error: %v", err) + } + + paths, ok := spec["paths"].(map[string]any) + if !ok { + t.Fatalf("expected paths object in generated spec, got %T", spec["paths"]) + } + + if _, ok := paths["/gql"]; !ok { + t.Fatal("expected GraphQL path to be included in generated spec") + } +} + +func TestAPISpecCmd_Good_SSEPathPopulatesSpec(t *testing.T) { + root := &cli.Command{Use: "root"} + AddAPICommands(root) + + outputFile := t.TempDir() + "/spec.json" + root.SetArgs([]string{ + "api", "spec", + "--sse-path", "/events", + "--output", outputFile, + }) + root.SetErr(new(bytes.Buffer)) + + if err := root.Execute(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + data, err := os.ReadFile(outputFile) + if err != nil { + t.Fatalf("expected spec file to be written: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("expected valid JSON spec, got error: %v", err) + } + + paths, ok := spec["paths"].(map[string]any) + if !ok { + t.Fatalf("expected paths object in generated spec, got %T", spec["paths"]) + } + + if _, ok := paths["/events"]; !ok { + t.Fatal("expected SSE path to be included in generated spec") + } +} + +func TestAPISpecCmd_Good_RuntimePathsPopulatedSpec(t *testing.T) { + root := &cli.Command{Use: "root"} + AddAPICommands(root) + + outputFile := t.TempDir() + "/spec.json" + root.SetArgs([]string{ + "api", "spec", + "--ws-path", "/ws", + "--pprof", + "--expvar", + "--output", outputFile, + }) + root.SetErr(new(bytes.Buffer)) + + if err := root.Execute(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + data, err := os.ReadFile(outputFile) + if err != nil { + t.Fatalf("expected spec file to be written: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("expected valid JSON spec, got error: %v", err) + } + + paths, ok := spec["paths"].(map[string]any) + if !ok { + t.Fatalf("expected paths object in generated spec, got %T", spec["paths"]) + } + + if _, ok := paths["/ws"]; !ok { + t.Fatal("expected WebSocket path to be included in generated spec") + } + if _, ok := paths["/debug/pprof"]; !ok { + t.Fatal("expected pprof path to be included in generated spec") + } + if _, ok := paths["/debug/vars"]; !ok { + t.Fatal("expected expvar path to be included in generated spec") + } +} + +func TestAPISpecCmd_Good_AuthentikFlagsPopulateSpecMetadata(t *testing.T) { + root := &cli.Command{Use: "root"} + AddAPICommands(root) + + outputFile := t.TempDir() + "/spec.json" + root.SetArgs([]string{ + "api", "spec", + "--authentik-issuer", "https://auth.example.com", + "--authentik-client-id", "core-client", + "--authentik-trusted-proxy", + "--authentik-public-paths", "/public, /docs, /public", + "--output", outputFile, + }) + root.SetErr(new(bytes.Buffer)) + + if err := root.Execute(); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + data, err := os.ReadFile(outputFile) + if err != nil { + t.Fatalf("expected spec file to be written: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("expected valid JSON spec, got error: %v", err) + } + + if got := spec["x-authentik-issuer"]; got != "https://auth.example.com" { + t.Fatalf("expected x-authentik-issuer=https://auth.example.com, got %v", got) + } + if got := spec["x-authentik-client-id"]; got != "core-client" { + t.Fatalf("expected x-authentik-client-id=core-client, got %v", got) + } + if got := spec["x-authentik-trusted-proxy"]; got != true { + t.Fatalf("expected x-authentik-trusted-proxy=true, got %v", got) + } + publicPaths, ok := spec["x-authentik-public-paths"].([]any) + if !ok { + t.Fatalf("expected x-authentik-public-paths array, got %T", spec["x-authentik-public-paths"]) + } + if len(publicPaths) != 4 || publicPaths[0] != "/health" || publicPaths[1] != "/swagger" || publicPaths[2] != "/public" || publicPaths[3] != "/docs" { + t.Fatalf("expected public paths [/health /swagger /public /docs], got %v", publicPaths) + } +} + +func TestAPISDKCmd_Bad_EmptyLanguages(t *testing.T) { + root := &cli.Command{Use: "root"} + AddAPICommands(root) + + root.SetArgs([]string{"api", "sdk", "--lang", " , , "}) + buf := new(bytes.Buffer) + root.SetOut(buf) + root.SetErr(buf) + + err := root.Execute() + if err == nil { + t.Fatal("expected error when --lang only contains empty values") + } } func TestAPISDKCmd_Bad_NoLang(t *testing.T) { @@ -98,4 +1114,309 @@ func TestAPISDKCmd_Good_ValidatesLanguage(t *testing.T) { if sdkCmd.Flag("package") == nil { t.Fatal("expected --package flag on sdk command") } + if sdkCmd.Flag("title") == nil { + t.Fatal("expected --title flag on sdk command") + } + if sdkCmd.Flag("description") == nil { + t.Fatal("expected --description flag on sdk command") + } + if sdkCmd.Flag("version") == nil { + t.Fatal("expected --version flag on sdk command") + } + if sdkCmd.Flag("swagger-path") == nil { + t.Fatal("expected --swagger-path flag on sdk command") + } + if sdkCmd.Flag("graphql-path") == nil { + t.Fatal("expected --graphql-path flag on sdk command") + } + if sdkCmd.Flag("sse-path") == nil { + t.Fatal("expected --sse-path flag on sdk command") + } + if sdkCmd.Flag("graphql-playground-path") == nil { + t.Fatal("expected --graphql-playground-path flag on sdk command") + } + if sdkCmd.Flag("ws-path") == nil { + t.Fatal("expected --ws-path flag on sdk command") + } + if sdkCmd.Flag("pprof") == nil { + t.Fatal("expected --pprof flag on sdk command") + } + if sdkCmd.Flag("expvar") == nil { + t.Fatal("expected --expvar flag on sdk command") + } + if sdkCmd.Flag("cache") == nil { + t.Fatal("expected --cache flag on sdk command") + } + if sdkCmd.Flag("cache-ttl") == nil { + t.Fatal("expected --cache-ttl flag on sdk command") + } + if sdkCmd.Flag("cache-max-entries") == nil { + t.Fatal("expected --cache-max-entries flag on sdk command") + } + if sdkCmd.Flag("cache-max-bytes") == nil { + t.Fatal("expected --cache-max-bytes flag on sdk command") + } + if sdkCmd.Flag("i18n-default-locale") == nil { + t.Fatal("expected --i18n-default-locale flag on sdk command") + } + if sdkCmd.Flag("i18n-supported-locales") == nil { + t.Fatal("expected --i18n-supported-locales flag on sdk command") + } + if sdkCmd.Flag("authentik-issuer") == nil { + t.Fatal("expected --authentik-issuer flag on sdk command") + } + if sdkCmd.Flag("authentik-client-id") == nil { + t.Fatal("expected --authentik-client-id flag on sdk command") + } + if sdkCmd.Flag("authentik-trusted-proxy") == nil { + t.Fatal("expected --authentik-trusted-proxy flag on sdk command") + } + if sdkCmd.Flag("authentik-public-paths") == nil { + t.Fatal("expected --authentik-public-paths flag on sdk command") + } + if sdkCmd.Flag("terms-of-service") == nil { + t.Fatal("expected --terms-of-service flag on sdk command") + } + if sdkCmd.Flag("contact-name") == nil { + t.Fatal("expected --contact-name flag on sdk command") + } + if sdkCmd.Flag("contact-url") == nil { + t.Fatal("expected --contact-url flag on sdk command") + } + if sdkCmd.Flag("contact-email") == nil { + t.Fatal("expected --contact-email flag on sdk command") + } + if sdkCmd.Flag("license-name") == nil { + t.Fatal("expected --license-name flag on sdk command") + } + if sdkCmd.Flag("license-url") == nil { + t.Fatal("expected --license-url flag on sdk command") + } + if sdkCmd.Flag("server") == nil { + t.Fatal("expected --server flag on sdk command") + } + if sdkCmd.Flag("security-schemes") == nil { + t.Fatal("expected --security-schemes flag on sdk command") + } +} + +func TestAPISDKCmd_Good_TempSpecUsesMetadataFlags(t *testing.T) { + snapshot := api.RegisteredSpecGroups() + api.ResetSpecGroups() + t.Cleanup(func() { + api.ResetSpecGroups() + api.RegisterSpecGroups(snapshot...) + }) + + api.RegisterSpecGroups(specCmdStubGroup{}) + + builder, err := sdkSpecBuilder(specBuilderConfig{ + title: "Custom SDK API", + summary: "Custom SDK overview", + description: "Custom SDK description", + version: "9.9.9", + swaggerPath: "/docs", + graphqlPath: "/gql", + graphqlPlayground: true, + graphqlPlaygroundPath: "/gql/ide", + ssePath: "/events", + wsPath: "/ws", + pprofEnabled: true, + expvarEnabled: true, + cacheEnabled: true, + cacheTTL: "5m0s", + cacheMaxEntries: 42, + cacheMaxBytes: 8192, + i18nDefaultLocale: "en-GB", + i18nSupportedLocales: "en-GB,fr,en-GB", + authentikIssuer: "https://auth.example.com", + authentikClientID: "core-client", + authentikTrustedProxy: true, + authentikPublicPaths: "/public, /docs, /public", + termsURL: "https://example.com/terms", + contactName: "SDK Support", + contactURL: "https://example.com/support", + contactEmail: "support@example.com", + licenseName: "EUPL-1.2", + licenseURL: "https://eupl.eu/1.2/en/", + servers: "https://api.example.com, /, https://api.example.com", + securitySchemes: `{"apiKeyAuth":{"type":"apiKey","in":"header","name":"X-API-Key"}}`, + }) + if err != nil { + t.Fatalf("unexpected error building sdk spec: %v", err) + } + if builder.GraphQLPlaygroundPath != "/gql/ide" { + t.Fatalf("expected custom GraphQL playground path to be preserved in builder, got %q", builder.GraphQLPlaygroundPath) + } + groups := collectRouteGroups(sdkSpecGroupsIter()) + + outputFile := t.TempDir() + "/spec.json" + if err := api.ExportSpecToFile(outputFile, "json", builder, groups); err != nil { + t.Fatalf("unexpected error writing temp spec: %v", err) + } + + data, err := os.ReadFile(outputFile) + if err != nil { + t.Fatalf("expected spec file to be written: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("expected valid JSON spec, got error: %v", err) + } + + info, ok := spec["info"].(map[string]any) + if !ok { + t.Fatal("expected info object in generated spec") + } + if info["title"] != "Custom SDK API" { + t.Fatalf("expected custom title, got %v", info["title"]) + } + if info["description"] != "Custom SDK description" { + t.Fatalf("expected custom description, got %v", info["description"]) + } + if info["summary"] != "Custom SDK overview" { + t.Fatalf("expected custom summary, got %v", info["summary"]) + } + if info["version"] != "9.9.9" { + t.Fatalf("expected custom version, got %v", info["version"]) + } + + paths, ok := spec["paths"].(map[string]any) + if !ok { + t.Fatalf("expected paths object in generated spec, got %T", spec["paths"]) + } + if _, ok := paths["/gql"]; !ok { + t.Fatal("expected GraphQL path to be included in generated spec") + } + if got := builder.SwaggerPath; got != "/docs" { + t.Fatalf("expected swagger path to be preserved in sdk spec builder, got %v", got) + } + if _, ok := paths["/gql/ide"]; !ok { + t.Fatalf("expected custom GraphQL playground path to be included in generated spec, got keys %v", paths) + } + if _, ok := paths["/gql/playground"]; ok { + t.Fatal("expected custom GraphQL playground path to replace the default playground path") + } + if _, ok := paths["/events"]; !ok { + t.Fatal("expected SSE path to be included in generated spec") + } + if _, ok := paths["/ws"]; !ok { + t.Fatal("expected WebSocket path to be included in generated spec") + } + if _, ok := paths["/debug/pprof"]; !ok { + t.Fatal("expected pprof path to be included in generated spec") + } + if _, ok := paths["/debug/vars"]; !ok { + t.Fatal("expected expvar path to be included in generated spec") + } + + if got := spec["x-cache-enabled"]; got != true { + t.Fatalf("expected x-cache-enabled=true, got %v", got) + } + if got := spec["x-cache-ttl"]; got != "5m0s" { + t.Fatalf("expected x-cache-ttl=5m0s, got %v", got) + } + if got := spec["x-cache-max-entries"]; got != float64(42) { + t.Fatalf("expected x-cache-max-entries=42, got %v", got) + } + if got := spec["x-cache-max-bytes"]; got != float64(8192) { + t.Fatalf("expected x-cache-max-bytes=8192, got %v", got) + } + if got := spec["x-i18n-default-locale"]; got != "en-GB" { + t.Fatalf("expected x-i18n-default-locale=en-GB, got %v", got) + } + locales, ok := spec["x-i18n-supported-locales"].([]any) + if !ok { + t.Fatalf("expected x-i18n-supported-locales array, got %T", spec["x-i18n-supported-locales"]) + } + if len(locales) != 2 || locales[0] != "en-GB" || locales[1] != "fr" { + t.Fatalf("expected supported locales [en-GB fr], got %v", locales) + } + if got := spec["x-authentik-issuer"]; got != "https://auth.example.com" { + t.Fatalf("expected x-authentik-issuer=https://auth.example.com, got %v", got) + } + if got := spec["x-authentik-client-id"]; got != "core-client" { + t.Fatalf("expected x-authentik-client-id=core-client, got %v", got) + } + if got := spec["x-authentik-trusted-proxy"]; got != true { + t.Fatalf("expected x-authentik-trusted-proxy=true, got %v", got) + } + publicPaths, ok := spec["x-authentik-public-paths"].([]any) + if !ok { + t.Fatalf("expected x-authentik-public-paths array, got %T", spec["x-authentik-public-paths"]) + } + if len(publicPaths) != 4 || publicPaths[0] != "/health" || publicPaths[1] != "/swagger" || publicPaths[2] != "/docs" || publicPaths[3] != "/public" { + t.Fatalf("expected public paths [/health /swagger /docs /public], got %v", publicPaths) + } + + if info["termsOfService"] != "https://example.com/terms" { + t.Fatalf("expected termsOfService to be preserved, got %v", info["termsOfService"]) + } + + contact, ok := info["contact"].(map[string]any) + if !ok { + t.Fatal("expected contact metadata in generated spec") + } + if contact["name"] != "SDK Support" { + t.Fatalf("expected contact name SDK Support, got %v", contact["name"]) + } + if contact["url"] != "https://example.com/support" { + t.Fatalf("expected contact url to be preserved, got %v", contact["url"]) + } + if contact["email"] != "support@example.com" { + t.Fatalf("expected contact email to be preserved, got %v", contact["email"]) + } + + license, ok := info["license"].(map[string]any) + if !ok { + t.Fatal("expected licence metadata in generated spec") + } + if license["name"] != "EUPL-1.2" { + t.Fatalf("expected licence name EUPL-1.2, got %v", license["name"]) + } + if license["url"] != "https://eupl.eu/1.2/en/" { + t.Fatalf("expected licence url to be preserved, got %v", license["url"]) + } + + servers, ok := spec["servers"].([]any) + if !ok { + t.Fatalf("expected servers array in generated spec, got %T", spec["servers"]) + } + if len(servers) != 2 { + t.Fatalf("expected 2 servers, got %d", len(servers)) + } + if servers[0].(map[string]any)["url"] != "https://api.example.com" { + t.Fatalf("expected first server to be https://api.example.com, got %v", servers[0]) + } + if servers[1].(map[string]any)["url"] != "/" { + t.Fatalf("expected second server to be /, got %v", servers[1]) + } + + securitySchemes, ok := spec["components"].(map[string]any)["securitySchemes"].(map[string]any) + if !ok { + t.Fatal("expected securitySchemes in generated spec") + } + if _, ok := securitySchemes["apiKeyAuth"].(map[string]any); !ok { + t.Fatalf("expected apiKeyAuth security scheme in generated spec, got %v", securitySchemes) + } +} + +func TestAPISDKCmd_Good_SpecGroupsDeduplicateToolBridge(t *testing.T) { + snapshot := api.RegisteredSpecGroups() + api.ResetSpecGroups() + t.Cleanup(func() { + api.ResetSpecGroups() + api.RegisterSpecGroups(snapshot...) + }) + + api.RegisterSpecGroups(api.NewToolBridge("/tools")) + + groups := collectRouteGroups(sdkSpecGroupsIter()) + if len(groups) != 1 { + t.Fatalf("expected the built-in tools bridge to be deduplicated, got %d groups", len(groups)) + } + if groups[0].BasePath() != "/tools" { + t.Fatalf("expected the remaining group to be /tools, got %s", groups[0].BasePath()) + } } diff --git a/cmd/api/spec_builder.go b/cmd/api/spec_builder.go new file mode 100644 index 0000000..bdd32b3 --- /dev/null +++ b/cmd/api/spec_builder.go @@ -0,0 +1,124 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api + +import ( + "strings" + "time" + + goapi "dappco.re/go/core/api" +) + +type specBuilderConfig struct { + title string + summary string + description string + version string + swaggerPath string + graphqlPath string + graphqlPlayground bool + graphqlPlaygroundPath string + ssePath string + wsPath string + pprofEnabled bool + expvarEnabled bool + cacheEnabled bool + cacheTTL string + cacheMaxEntries int + cacheMaxBytes int + i18nDefaultLocale string + i18nSupportedLocales string + authentikIssuer string + authentikClientID string + authentikTrustedProxy bool + authentikPublicPaths string + termsURL string + contactName string + contactURL string + contactEmail string + licenseName string + licenseURL string + externalDocsDescription string + externalDocsURL string + servers string + securitySchemes string +} + +func newSpecBuilder(cfg specBuilderConfig) (*goapi.SpecBuilder, error) { + swaggerPath := strings.TrimSpace(cfg.swaggerPath) + graphqlPath := strings.TrimSpace(cfg.graphqlPath) + ssePath := strings.TrimSpace(cfg.ssePath) + wsPath := strings.TrimSpace(cfg.wsPath) + cacheTTL := strings.TrimSpace(cfg.cacheTTL) + cacheTTLValid := parsePositiveDuration(cacheTTL) + + builder := &goapi.SpecBuilder{ + Title: strings.TrimSpace(cfg.title), + Summary: strings.TrimSpace(cfg.summary), + Description: strings.TrimSpace(cfg.description), + Version: strings.TrimSpace(cfg.version), + SwaggerEnabled: swaggerPath != "", + SwaggerPath: swaggerPath, + GraphQLEnabled: graphqlPath != "" || cfg.graphqlPlayground, + GraphQLPath: graphqlPath, + GraphQLPlayground: cfg.graphqlPlayground, + GraphQLPlaygroundPath: strings.TrimSpace(cfg.graphqlPlaygroundPath), + SSEEnabled: ssePath != "", + SSEPath: ssePath, + WSEnabled: wsPath != "", + WSPath: wsPath, + PprofEnabled: cfg.pprofEnabled, + ExpvarEnabled: cfg.expvarEnabled, + CacheEnabled: cfg.cacheEnabled || cacheTTLValid, + CacheTTL: cacheTTL, + CacheMaxEntries: cfg.cacheMaxEntries, + CacheMaxBytes: cfg.cacheMaxBytes, + I18nDefaultLocale: strings.TrimSpace(cfg.i18nDefaultLocale), + TermsOfService: strings.TrimSpace(cfg.termsURL), + ContactName: strings.TrimSpace(cfg.contactName), + ContactURL: strings.TrimSpace(cfg.contactURL), + ContactEmail: strings.TrimSpace(cfg.contactEmail), + Servers: parseServers(cfg.servers), + LicenseName: strings.TrimSpace(cfg.licenseName), + LicenseURL: strings.TrimSpace(cfg.licenseURL), + ExternalDocsDescription: strings.TrimSpace(cfg.externalDocsDescription), + ExternalDocsURL: strings.TrimSpace(cfg.externalDocsURL), + AuthentikIssuer: strings.TrimSpace(cfg.authentikIssuer), + AuthentikClientID: strings.TrimSpace(cfg.authentikClientID), + AuthentikTrustedProxy: cfg.authentikTrustedProxy, + AuthentikPublicPaths: normalisePublicPaths(splitUniqueCSV(cfg.authentikPublicPaths)), + } + + builder.I18nSupportedLocales = parseLocales(cfg.i18nSupportedLocales) + if builder.I18nDefaultLocale == "" && len(builder.I18nSupportedLocales) > 0 { + builder.I18nDefaultLocale = "en" + } + + if cfg.securitySchemes != "" { + schemes, err := parseSecuritySchemes(cfg.securitySchemes) + if err != nil { + return nil, err + } + builder.SecuritySchemes = schemes + } + + return builder, nil +} + +func parseLocales(raw string) []string { + return splitUniqueCSV(raw) +} + +func parsePositiveDuration(raw string) bool { + raw = strings.TrimSpace(raw) + if raw == "" { + return false + } + + d, err := time.ParseDuration(raw) + if err != nil || d <= 0 { + return false + } + + return true +} diff --git a/cmd/api/spec_groups_iter.go b/cmd/api/spec_groups_iter.go new file mode 100644 index 0000000..208de61 --- /dev/null +++ b/cmd/api/spec_groups_iter.go @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api + +import ( + "iter" + + goapi "dappco.re/go/core/api" +) + +// specGroupsIter snapshots the registered spec groups and appends one optional +// extra group. It keeps the command paths iterator-backed while preserving the +// existing ordering guarantees. +func specGroupsIter(extra goapi.RouteGroup) iter.Seq[goapi.RouteGroup] { + return goapi.SpecGroupsIter(extra) +} diff --git a/codegen.go b/codegen.go index b8cb12e..e031dea 100644 --- a/codegen.go +++ b/codegen.go @@ -11,6 +11,7 @@ import ( "os/exec" "path/filepath" "slices" + "strings" coreio "dappco.re/go/core/io" coreerr "dappco.re/go/core/log" @@ -32,6 +33,10 @@ var supportedLanguages = map[string]string{ } // SDKGenerator wraps openapi-generator-cli for SDK generation. +// +// Example: +// +// gen := &api.SDKGenerator{SpecPath: "./openapi.yaml", OutputDir: "./sdk", PackageName: "service"} type SDKGenerator struct { // SpecPath is the path to the OpenAPI spec file (JSON or YAML). SpecPath string @@ -45,22 +50,50 @@ type SDKGenerator struct { // Generate creates an SDK for the given language using openapi-generator-cli. // The language must be one of the supported languages returned by SupportedLanguages(). +// +// Example: +// +// err := gen.Generate(context.Background(), "go") func (g *SDKGenerator) Generate(ctx context.Context, language string) error { + if g == nil { + return coreerr.E("SDKGenerator.Generate", "generator is nil", nil) + } + if ctx == nil { + return coreerr.E("SDKGenerator.Generate", "context is nil", nil) + } + + language = strings.TrimSpace(language) generator, ok := supportedLanguages[language] if !ok { return coreerr.E("SDKGenerator.Generate", fmt.Sprintf("unsupported language %q: supported languages are %v", language, SupportedLanguages()), nil) } - if _, err := os.Stat(g.SpecPath); os.IsNotExist(err) { - return coreerr.E("SDKGenerator.Generate", "spec file not found: "+g.SpecPath, nil) + specPath := strings.TrimSpace(g.SpecPath) + if specPath == "" { + return coreerr.E("SDKGenerator.Generate", "spec path is required", nil) + } + if _, err := os.Stat(specPath); err != nil { + if os.IsNotExist(err) { + return coreerr.E("SDKGenerator.Generate", "spec file not found: "+specPath, nil) + } + return coreerr.E("SDKGenerator.Generate", "stat spec file", err) + } + + outputBase := strings.TrimSpace(g.OutputDir) + if outputBase == "" { + return coreerr.E("SDKGenerator.Generate", "output directory is required", nil) + } + + if !g.Available() { + return coreerr.E("SDKGenerator.Generate", "openapi-generator-cli not installed", nil) } - outputDir := filepath.Join(g.OutputDir, language) + outputDir := filepath.Join(outputBase, language) if err := coreio.Local.EnsureDir(outputDir); err != nil { return coreerr.E("SDKGenerator.Generate", "create output directory", err) } - args := g.buildArgs(generator, outputDir) + args := g.buildArgs(specPath, generator, outputDir) cmd := exec.CommandContext(ctx, "openapi-generator-cli", args...) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr @@ -73,10 +106,10 @@ func (g *SDKGenerator) Generate(ctx context.Context, language string) error { } // buildArgs constructs the openapi-generator-cli command arguments. -func (g *SDKGenerator) buildArgs(generator, outputDir string) []string { +func (g *SDKGenerator) buildArgs(specPath, generator, outputDir string) []string { args := []string{ "generate", - "-i", g.SpecPath, + "-i", specPath, "-g", generator, "-o", outputDir, } @@ -87,6 +120,12 @@ func (g *SDKGenerator) buildArgs(generator, outputDir string) []string { } // Available checks if openapi-generator-cli is installed and accessible. +// +// Example: +// +// if !gen.Available() { +// t.Fatal("openapi-generator-cli is required") +// } func (g *SDKGenerator) Available() bool { _, err := exec.LookPath("openapi-generator-cli") return err == nil @@ -94,11 +133,21 @@ func (g *SDKGenerator) Available() bool { // SupportedLanguages returns the list of supported SDK target languages // in sorted order for deterministic output. +// +// Example: +// +// langs := api.SupportedLanguages() func SupportedLanguages() []string { return slices.Sorted(maps.Keys(supportedLanguages)) } // SupportedLanguagesIter returns an iterator over supported SDK target languages in sorted order. +// +// Example: +// +// for lang := range api.SupportedLanguagesIter() { +// fmt.Println(lang) +// } func SupportedLanguagesIter() iter.Seq[string] { return slices.Values(SupportedLanguages()) } diff --git a/codegen_test.go b/codegen_test.go index dcb058d..5d3b580 100644 --- a/codegen_test.go +++ b/codegen_test.go @@ -59,7 +59,108 @@ func TestSDKGenerator_Bad_MissingSpec(t *testing.T) { } } +func TestSDKGenerator_Bad_EmptySpecPath(t *testing.T) { + gen := &api.SDKGenerator{ + OutputDir: t.TempDir(), + } + + err := gen.Generate(context.Background(), "go") + if err == nil { + t.Fatal("expected error for empty spec path, got nil") + } + if !strings.Contains(err.Error(), "spec path is required") { + t.Fatalf("expected error to contain 'spec path is required', got: %v", err) + } +} + +func TestSDKGenerator_Bad_EmptyOutputDir(t *testing.T) { + specDir := t.TempDir() + specPath := filepath.Join(specDir, "spec.json") + if err := os.WriteFile(specPath, []byte(`{"openapi":"3.1.0"}`), 0o644); err != nil { + t.Fatalf("failed to write spec file: %v", err) + } + + gen := &api.SDKGenerator{ + SpecPath: specPath, + } + + err := gen.Generate(context.Background(), "go") + if err == nil { + t.Fatal("expected error for empty output directory, got nil") + } + if !strings.Contains(err.Error(), "output directory is required") { + t.Fatalf("expected error to contain 'output directory is required', got: %v", err) + } +} + +func TestSDKGenerator_Bad_NilContext(t *testing.T) { + gen := &api.SDKGenerator{ + SpecPath: filepath.Join(t.TempDir(), "nonexistent.json"), + OutputDir: t.TempDir(), + } + + err := gen.Generate(nil, "go") + if err == nil { + t.Fatal("expected error for nil context, got nil") + } + if !strings.Contains(err.Error(), "context is nil") { + t.Fatalf("expected error to contain 'context is nil', got: %v", err) + } +} + +func TestSDKGenerator_Bad_NilReceiver(t *testing.T) { + var gen *api.SDKGenerator + + err := gen.Generate(context.Background(), "go") + if err == nil { + t.Fatal("expected error for nil generator, got nil") + } + if !strings.Contains(err.Error(), "generator is nil") { + t.Fatalf("expected error to contain 'generator is nil', got: %v", err) + } +} + +func TestSDKGenerator_Bad_MissingGenerator(t *testing.T) { + t.Setenv("PATH", t.TempDir()) + + specDir := t.TempDir() + specPath := filepath.Join(specDir, "spec.json") + if err := os.WriteFile(specPath, []byte(`{"openapi":"3.1.0"}`), 0o644); err != nil { + t.Fatalf("failed to write spec file: %v", err) + } + + outputDir := filepath.Join(t.TempDir(), "nested", "sdk") + gen := &api.SDKGenerator{ + SpecPath: specPath, + OutputDir: outputDir, + } + + err := gen.Generate(context.Background(), "go") + if err == nil { + t.Fatal("expected error when openapi-generator-cli is missing, got nil") + } + if !strings.Contains(err.Error(), "openapi-generator-cli not installed") { + t.Fatalf("expected missing-generator error, got: %v", err) + } + + if _, statErr := os.Stat(filepath.Join(outputDir, "go")); !os.IsNotExist(statErr) { + t.Fatalf("expected output directory not to be created when generator is missing, got err=%v", statErr) + } +} + func TestSDKGenerator_Good_OutputDirCreated(t *testing.T) { + oldPath := os.Getenv("PATH") + + // Provide a fake openapi-generator-cli so Generate reaches the exec step + // without depending on the host environment. + binDir := t.TempDir() + binPath := filepath.Join(binDir, "openapi-generator-cli") + script := []byte("#!/bin/sh\nexit 1\n") + if err := os.WriteFile(binPath, script, 0o755); err != nil { + t.Fatalf("failed to write fake generator: %v", err) + } + t.Setenv("PATH", binDir+string(os.PathListSeparator)+oldPath) + // Write a minimal spec file so we pass the file-exists check. specDir := t.TempDir() specPath := filepath.Join(specDir, "spec.json") @@ -73,8 +174,8 @@ func TestSDKGenerator_Good_OutputDirCreated(t *testing.T) { OutputDir: outputDir, } - // Generate will fail at the exec step (openapi-generator-cli likely not installed), - // but the output directory should have been created before that. + // Generate will fail at the exec step, but the output directory should have + // been created before the CLI returned its non-zero status. _ = gen.Generate(context.Background(), "go") expected := filepath.Join(outputDir, "go") diff --git a/docs/architecture.md b/docs/architecture.md index db1deb4..ec0222b 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -30,6 +30,8 @@ type Engine struct { swaggerTitle string swaggerDesc string swaggerVersion string + swaggerExternalDocsDescription string + swaggerExternalDocsURL string pprofEnabled bool expvarEnabled bool graphql *graphqlConfig @@ -128,6 +130,9 @@ type RouteDescription struct { Summary string Description string Tags []string + Deprecated bool + StatusCode int + Parameters []ParameterDescription RequestBody map[string]any Response map[string]any } @@ -151,12 +156,19 @@ They execute after `gin.Recovery()` but before any route handler. The `Option` t | `WithAddr(addr)` | Listen address | Default `:8080` | | `WithBearerAuth(token)` | Static bearer token authentication | Skips `/health` and `/swagger` | | `WithRequestID()` | `X-Request-ID` propagation | Preserves client-supplied IDs; generates 16-byte hex otherwise | +| `WithResponseMeta()` | Request metadata in JSON envelopes | Merges `request_id` and `duration` into standard responses | | `WithCORS(origins...)` | CORS policy | `"*"` enables `AllowAllOrigins`; 12-hour `MaxAge` | +| `WithRateLimit(limit)` | Per-IP token-bucket rate limiting | `429 Too Many Requests`; `X-RateLimit-*` on success; `Retry-After` on rejection; zero or negative disables | | `WithMiddleware(mw...)` | Arbitrary Gin middleware | Escape hatch for custom middleware | | `WithStatic(prefix, root)` | Static file serving | Directory listing disabled | | `WithWSHandler(h)` | WebSocket at `/ws` | Wraps any `http.Handler` | | `WithAuthentik(cfg)` | Authentik forward-auth + OIDC JWT | Permissive; populates context, never rejects | | `WithSwagger(title, desc, ver)` | Swagger UI at `/swagger/` | Runtime spec via `SpecBuilder` | +| `WithSwaggerTermsOfService(url)` | OpenAPI terms of service metadata | Populates the Swagger spec info block without manual `SpecBuilder` wiring | +| `WithSwaggerContact(name, url, email)` | OpenAPI contact metadata | Populates the Swagger spec info block without manual `SpecBuilder` wiring | +| `WithSwaggerServers(servers...)` | OpenAPI server metadata | Feeds the runtime Swagger spec and exported docs | +| `WithSwaggerLicense(name, url)` | OpenAPI licence metadata | Populates the Swagger spec info block without manual `SpecBuilder` wiring | +| `WithSwaggerExternalDocs(description, url)` | OpenAPI external documentation metadata | Populates the top-level `externalDocs` block without manual `SpecBuilder` wiring | | `WithPprof()` | Go profiling at `/debug/pprof/` | WARNING: do not expose in production without authentication | | `WithExpvar()` | Runtime metrics at `/debug/vars` | WARNING: do not expose in production without authentication | | `WithSecure()` | Security headers | HSTS 1 year, X-Frame-Options DENY, nosniff, strict referrer | @@ -164,7 +176,8 @@ They execute after `gin.Recovery()` but before any route handler. The `Option` t | `WithBrotli(level...)` | Brotli response compression | Writer pool for efficiency; default compression if level omitted | | `WithSlog(logger)` | Structured request logging | Falls back to `slog.Default()` if nil | | `WithTimeout(d)` | Per-request deadline | 504 with standard error envelope on timeout | -| `WithCache(ttl)` | In-memory GET response caching | `X-Cache: HIT` header on cache hits; 2xx only | +| `WithCache(ttl)` | In-memory GET response caching | Compatibility wrapper for `WithCacheLimits(ttl, 0, 0)`; `X-Cache: HIT` header on cache hits; 2xx only | +| `WithCacheLimits(ttl, maxEntries, maxBytes)` | In-memory GET response caching with explicit bounds | Clearer cache configuration when eviction policy should be self-documenting | | `WithSessions(name, secret)` | Cookie-backed server sessions | gin-contrib/sessions with cookie store | | `WithAuthz(enforcer)` | Casbin policy-based authorisation | Subject from HTTP Basic Auth; 403 on deny | | `WithHTTPSign(secrets, opts...)` | HTTP Signatures verification | draft-cavage-http-signatures; 401/400 on failure | @@ -371,14 +384,19 @@ redirects and introspection). The GraphQL handler is created via gqlgen's ## 8. Response Caching -`WithCache(ttl)` installs a URL-keyed in-memory response cache scoped to GET requests: +`WithCacheLimits(ttl, maxEntries, maxBytes)` installs a URL-keyed in-memory response cache scoped to GET requests: + +```go +engine, _ := api.New(api.WithCacheLimits(5*time.Minute, 100, 10<<20)) +``` - Only successful 2xx responses are cached. - Non-GET methods pass through uncached. - Cached responses are served with an `X-Cache: HIT` header. - Expired entries are evicted lazily on the next access for the same key. - The cache is not shared across `Engine` instances. -- There is no size limit on the cache. +- `WithCache(ttl)` remains available as a compatibility wrapper for callers that do not need to spell out the bounds. +- Passing non-positive values to `WithCacheLimits` leaves that limit unbounded. The implementation uses a `cacheWriter` that wraps `gin.ResponseWriter` to intercept and capture the response body and status code for storage. @@ -573,7 +591,9 @@ Generates an OpenAPI 3.1 specification from registered route groups. | `--output` | `-o` | (stdout) | Write spec to file | | `--format` | `-f` | `json` | Output format: `json` or `yaml` | | `--title` | `-t` | `Lethean Core API` | API title | +| `--description` | `-d` | `Lethean Core API` | API description | | `--version` | `-V` | `1.0.0` | API version | +| `--server` | `-S` | (none) | Comma-separated OpenAPI server URL(s) | ### `core api sdk` @@ -585,6 +605,10 @@ Generates client SDKs from an OpenAPI spec using `openapi-generator-cli`. | `--output` | `-o` | `./sdk` | Output directory | | `--spec` | `-s` | (auto-generated) | Path to existing OpenAPI spec | | `--package` | `-p` | `lethean` | Package name for generated SDK | +| `--title` | `-t` | `Lethean Core API` | API title in generated spec | +| `--description` | `-d` | `Lethean Core API` | API description in generated spec | +| `--version` | `-V` | `1.0.0` | API version in generated spec | +| `--server` | `-S` | (none) | Comma-separated OpenAPI server URL(s) | --- diff --git a/docs/history.md b/docs/history.md index 823e360..f2a6f81 100644 --- a/docs/history.md +++ b/docs/history.md @@ -169,11 +169,12 @@ At the end of Phase 3, the module has 176 tests. ## Known Limitations -### 1. Cache has no size limit +### 1. Cache remains in-memory -`WithCache(ttl)` stores all successful GET responses in memory with no maximum entry count or -total size bound. For a server receiving requests to many distinct URLs, the cache will grow -without bound. A LRU eviction policy or a configurable maximum is the natural next step. +`WithCache(ttl, maxEntries, maxBytes)` can now bound the cache by entry count and approximate +payload size, but it still stores responses in memory. Workloads with very large cached bodies +or a long-lived process will still consume RAM, so a disk-backed cache would be the next step if +that becomes a concern. ### 2. SDK codegen requires an external binary diff --git a/docs/index.md b/docs/index.md index 3dec037..d4292c3 100644 --- a/docs/index.md +++ b/docs/index.md @@ -44,6 +44,7 @@ func main() { api.WithSecure(), api.WithSlog(nil), api.WithSwagger("My API", "A service description", "1.0.0"), + api.WithSwaggerLicense("EUPL-1.2", "https://eupl.eu/1.2/en/"), ) engine.Register(myRoutes) // any RouteGroup implementation @@ -94,7 +95,7 @@ engine.Register(&Routes{service: svc}) | File | Purpose | |------|---------| | `api.go` | `Engine` struct, `New()`, `build()`, `Serve()`, `Handler()`, `Channels()` | -| `options.go` | All `With*()` option functions (25 options) | +| `options.go` | All `With*()` option functions (28 options) | | `group.go` | `RouteGroup`, `StreamGroup`, `DescribableGroup` interfaces; `RouteDescription` | | `response.go` | `Response[T]`, `Error`, `Meta`, `OK()`, `Fail()`, `FailWithDetails()`, `Paginated()` | | `middleware.go` | `bearerAuthMiddleware()`, `requestIDMiddleware()` | diff --git a/export.go b/export.go index f514ed9..fadff80 100644 --- a/export.go +++ b/export.go @@ -4,9 +4,12 @@ package api import ( "encoding/json" + "fmt" "io" + "iter" "os" "path/filepath" + "strings" "gopkg.in/yaml.v3" @@ -16,43 +19,112 @@ import ( // ExportSpec generates the OpenAPI spec and writes it to w. // Format must be "json" or "yaml". +// +// Example: +// +// _ = api.ExportSpec(os.Stdout, "yaml", builder, engine.Groups()) func ExportSpec(w io.Writer, format string, builder *SpecBuilder, groups []RouteGroup) error { data, err := builder.Build(groups) if err != nil { return coreerr.E("ExportSpec", "build spec", err) } - switch format { + return writeSpec(w, format, data, "ExportSpec") +} + +// ExportSpecIter generates the OpenAPI spec from an iterator and writes it to w. +// Format must be "json" or "yaml". +// +// Example: +// +// _ = api.ExportSpecIter(os.Stdout, "json", builder, api.RegisteredSpecGroupsIter()) +func ExportSpecIter(w io.Writer, format string, builder *SpecBuilder, groups iter.Seq[RouteGroup]) error { + data, err := builder.BuildIter(groups) + if err != nil { + return coreerr.E("ExportSpecIter", "build spec", err) + } + + return writeSpec(w, format, data, "ExportSpecIter") +} + +func writeSpec(w io.Writer, format string, data []byte, op string) error { + switch strings.ToLower(strings.TrimSpace(format)) { case "json": - _, err = w.Write(data) + _, err := w.Write(data) return err case "yaml": // Unmarshal JSON then re-marshal as YAML. var obj any if err := json.Unmarshal(data, &obj); err != nil { - return coreerr.E("ExportSpec", "unmarshal spec", err) + return coreerr.E(op, "unmarshal spec", err) } enc := yaml.NewEncoder(w) enc.SetIndent(2) if err := enc.Encode(obj); err != nil { - return coreerr.E("ExportSpec", "encode yaml", err) + return coreerr.E(op, "encode yaml", err) } return enc.Close() default: - return coreerr.E("ExportSpec", "unsupported format "+format+": use \"json\" or \"yaml\"", nil) + return coreerr.E(op, fmt.Sprintf("unsupported format %s: use %q or %q", format, "json", "yaml"), nil) } } // ExportSpecToFile writes the spec to the given path. // The parent directory is created if it does not exist. +// +// Example: +// +// _ = api.ExportSpecToFile("./api/openapi.yaml", "yaml", builder, engine.Groups()) func ExportSpecToFile(path, format string, builder *SpecBuilder, groups []RouteGroup) error { - if err := coreio.Local.EnsureDir(filepath.Dir(path)); err != nil { - return coreerr.E("ExportSpecToFile", "create directory", err) + return exportSpecToFile(path, "ExportSpecToFile", func(w io.Writer) error { + return ExportSpec(w, format, builder, groups) + }) +} + +// ExportSpecToFileIter writes the OpenAPI spec from an iterator to the given path. +// The parent directory is created if it does not exist. +// +// Example: +// +// _ = api.ExportSpecToFileIter("./api/openapi.json", "json", builder, api.RegisteredSpecGroupsIter()) +func ExportSpecToFileIter(path, format string, builder *SpecBuilder, groups iter.Seq[RouteGroup]) error { + return exportSpecToFile(path, "ExportSpecToFileIter", func(w io.Writer) error { + return ExportSpecIter(w, format, builder, groups) + }) +} + +func exportSpecToFile(path, op string, write func(io.Writer) error) (err error) { + dir := filepath.Dir(path) + if err := coreio.Local.EnsureDir(dir); err != nil { + return coreerr.E(op, "create directory", err) } - f, err := os.Create(path) + + // Write to a temp file in the same directory so the rename is atomic on + // most filesystems. The destination is never truncated unless the full + // export succeeds. + f, err := os.CreateTemp(dir, ".export-*.tmp") if err != nil { - return coreerr.E("ExportSpecToFile", "create file", err) + return coreerr.E(op, "create temp file", err) + } + tmpPath := f.Name() + + defer func() { + if err != nil { + _ = os.Remove(tmpPath) + } + }() + + if writeErr := write(f); writeErr != nil { + _ = f.Close() + return writeErr + } + + if closeErr := f.Close(); closeErr != nil { + return coreerr.E(op, "close temp file", closeErr) + } + + if renameErr := os.Rename(tmpPath, path); renameErr != nil { + return coreerr.E(op, "rename temp file", renameErr) } - defer f.Close() - return ExportSpec(f, format, builder, groups) + return nil } diff --git a/export_test.go b/export_test.go index 1a26e33..926b5c3 100644 --- a/export_test.go +++ b/export_test.go @@ -5,6 +5,7 @@ package api_test import ( "bytes" "encoding/json" + "iter" "net/http" "os" "path/filepath" @@ -65,6 +66,24 @@ func TestExportSpec_Good_YAML(t *testing.T) { } } +func TestExportSpec_Good_NormalisesFormatInput(t *testing.T) { + builder := &api.SpecBuilder{Title: "Test", Description: "Test API", Version: "1.0.0"} + + var buf bytes.Buffer + if err := api.ExportSpec(&buf, " YAML ", builder, nil); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := yaml.Unmarshal(buf.Bytes(), &spec); err != nil { + t.Fatalf("output is not valid YAML: %v", err) + } + + if spec["openapi"] != "3.1.0" { + t.Fatalf("expected openapi=3.1.0, got %v", spec["openapi"]) + } +} + func TestExportSpec_Bad_InvalidFormat(t *testing.T) { builder := &api.SpecBuilder{Title: "Test", Description: "Test API", Version: "1.0.0"} @@ -164,3 +183,41 @@ func TestExportSpec_Good_WithToolBridge(t *testing.T) { t.Fatal("expected /tools/metrics_query path in spec") } } + +func TestExportSpecIter_Good_WithGroupIterator(t *testing.T) { + builder := &api.SpecBuilder{Title: "Test", Description: "Test API", Version: "1.0.0"} + + group := &specStubGroup{ + name: "iter", + basePath: "/iter", + descs: []api.RouteDescription{ + { + Method: "GET", + Path: "/ping", + Summary: "Ping iter group", + Response: map[string]any{ + "type": "string", + }, + }, + }, + } + + groups := iter.Seq[api.RouteGroup](func(yield func(api.RouteGroup) bool) { + _ = yield(group) + }) + + var buf bytes.Buffer + if err := api.ExportSpecIter(&buf, "json", builder, groups); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(buf.Bytes(), &spec); err != nil { + t.Fatalf("output is not valid JSON: %v", err) + } + + paths := spec["paths"].(map[string]any) + if _, ok := paths["/iter/ping"]; !ok { + t.Fatal("expected /iter/ping path in spec") + } +} diff --git a/go-io/go.mod b/go-io/go.mod new file mode 100644 index 0000000..af101a6 --- /dev/null +++ b/go-io/go.mod @@ -0,0 +1,3 @@ +module dappco.re/go/core/io + +go 1.26.0 diff --git a/go-io/local.go b/go-io/local.go new file mode 100644 index 0000000..5bfb15d --- /dev/null +++ b/go-io/local.go @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package io + +import "os" + +// LocalFS provides simple local filesystem helpers used by the API module. +var Local localFS + +type localFS struct{} + +// EnsureDir creates the directory path if it does not already exist. +func (localFS) EnsureDir(path string) error { + if path == "" || path == "." { + return nil + } + return os.MkdirAll(path, 0o755) +} + +// Delete removes the named file, ignoring missing files. +func (localFS) Delete(path string) error { + if path == "" { + return nil + } + if err := os.Remove(path); err != nil && !os.IsNotExist(err) { + return err + } + return nil +} diff --git a/go-log/error.go b/go-log/error.go new file mode 100644 index 0000000..939c8bf --- /dev/null +++ b/go-log/error.go @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package log + +import "fmt" + +// E wraps an operation label and message in a conventional error. +// If err is non-nil, it is wrapped with %w. +func E(op, message string, err error) error { + if err != nil { + return fmt.Errorf("%s: %s: %w", op, message, err) + } + return fmt.Errorf("%s: %s", op, message) +} diff --git a/go-log/go.mod b/go-log/go.mod new file mode 100644 index 0000000..c513da7 --- /dev/null +++ b/go-log/go.mod @@ -0,0 +1,3 @@ +module dappco.re/go/core/log + +go 1.26.0 diff --git a/go.mod b/go.mod index 50f8b1d..a66b7dc 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.26.0 require ( dappco.re/go/core/io v0.1.7 dappco.re/go/core/log v0.0.4 - forge.lthn.ai/core/cli v0.3.7 + dappco.re/go/core/cli v0.3.7 github.com/99designs/gqlgen v0.17.88 github.com/andybalholm/brotli v1.2.0 github.com/casbin/casbin/v2 v2.135.0 @@ -38,10 +38,10 @@ require ( ) require ( - forge.lthn.ai/core/go v0.3.2 // indirect - forge.lthn.ai/core/go-i18n v0.1.7 // indirect - forge.lthn.ai/core/go-inference v0.1.7 // indirect - forge.lthn.ai/core/go-log v0.0.4 // indirect + dappco.re/go/core v0.3.2 // indirect + dappco.re/go/core/i18n v0.1.7 // indirect + dappco.re/go/core/inference v0.1.7 // indirect + dappco.re/go/core/log v0.0.4 // indirect github.com/KyleBanks/depth v1.2.1 // indirect github.com/agnivade/levenshtein v1.2.1 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect @@ -132,6 +132,6 @@ require ( replace ( dappco.re/go/core => ../go dappco.re/go/core/i18n => ../go-i18n - dappco.re/go/core/io => ../go-io - dappco.re/go/core/log => ../go-log + dappco.re/go/core/io => ./go-io + dappco.re/go/core/log => ./go-log ) diff --git a/graphql.go b/graphql.go index c878ee3..0a3298c 100644 --- a/graphql.go +++ b/graphql.go @@ -4,6 +4,7 @@ package api import ( "net/http" + "strings" "github.com/99designs/gqlgen/graphql" "github.com/99designs/gqlgen/graphql/handler" @@ -21,10 +22,61 @@ type graphqlConfig struct { playground bool } +// GraphQLConfig captures the configured GraphQL endpoint settings for an Engine. +// +// It is intentionally small and serialisable so callers can inspect the active +// GraphQL surface without reaching into the internal handler configuration. +// +// Example: +// +// cfg := api.GraphQLConfig{Enabled: true, Path: "/graphql", Playground: true} +type GraphQLConfig struct { + Enabled bool + Path string + Playground bool + PlaygroundPath string +} + +// GraphQLConfig returns the currently configured GraphQL settings for the engine. +// +// The result snapshots the Engine state at call time and normalises any configured +// URL path using the same rules as the runtime handlers. +// +// Example: +// +// cfg := engine.GraphQLConfig() +func (e *Engine) GraphQLConfig() GraphQLConfig { + if e == nil { + return GraphQLConfig{} + } + + cfg := GraphQLConfig{ + Enabled: e.graphql != nil, + Playground: e.graphql != nil && e.graphql.playground, + } + + if e.graphql != nil { + cfg.Path = normaliseGraphQLPath(e.graphql.path) + if e.graphql.playground { + cfg.PlaygroundPath = cfg.Path + "/playground" + } + } + + return cfg +} + // GraphQLOption configures a GraphQL endpoint. +// +// Example: +// +// opts := []api.GraphQLOption{api.WithPlayground(), api.WithGraphQLPath("/gql")} type GraphQLOption func(*graphqlConfig) // WithPlayground enables the GraphQL Playground UI at {path}/playground. +// +// Example: +// +// api.WithGraphQL(schema, api.WithPlayground()) func WithPlayground() GraphQLOption { return func(cfg *graphqlConfig) { cfg.playground = true @@ -33,9 +85,13 @@ func WithPlayground() GraphQLOption { // WithGraphQLPath sets a custom URL path for the GraphQL endpoint. // The default path is "/graphql". +// +// Example: +// +// api.WithGraphQL(schema, api.WithGraphQLPath("/gql")) func WithGraphQLPath(path string) GraphQLOption { return func(cfg *graphqlConfig) { - cfg.path = path + cfg.path = normaliseGraphQLPath(path) } } @@ -55,6 +111,22 @@ func mountGraphQL(r *gin.Engine, cfg *graphqlConfig) { } } +// normaliseGraphQLPath coerces custom GraphQL paths into a stable form. +// The path always begins with a single slash and never ends with one. +func normaliseGraphQLPath(path string) string { + path = strings.TrimSpace(path) + if path == "" { + return defaultGraphQLPath + } + + path = "/" + strings.Trim(path, "/") + if path == "/" { + return defaultGraphQLPath + } + + return path +} + // wrapHTTPHandler adapts a standard http.Handler to a Gin handler function. func wrapHTTPHandler(h http.Handler) gin.HandlerFunc { return func(c *gin.Context) { diff --git a/graphql_config_test.go b/graphql_config_test.go new file mode 100644 index 0000000..83bbc8d --- /dev/null +++ b/graphql_config_test.go @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api_test + +import ( + "testing" + + "github.com/gin-gonic/gin" + + api "dappco.re/go/core/api" +) + +func TestEngine_GraphQLConfig_Good_SnapshotsCurrentSettings(t *testing.T) { + gin.SetMode(gin.TestMode) + + e, err := api.New( + api.WithGraphQL(newTestSchema(), api.WithPlayground(), api.WithGraphQLPath(" /gql/ ")), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + cfg := e.GraphQLConfig() + if !cfg.Enabled { + t.Fatal("expected GraphQL to be enabled") + } + if cfg.Path != "/gql" { + t.Fatalf("expected GraphQL path /gql, got %q", cfg.Path) + } + if !cfg.Playground { + t.Fatal("expected GraphQL playground to be enabled") + } + if cfg.PlaygroundPath != "/gql/playground" { + t.Fatalf("expected GraphQL playground path /gql/playground, got %q", cfg.PlaygroundPath) + } +} + +func TestEngine_GraphQLConfig_Good_EmptyOnNilEngine(t *testing.T) { + var e *api.Engine + + cfg := e.GraphQLConfig() + if cfg.Enabled || cfg.Path != "" || cfg.Playground || cfg.PlaygroundPath != "" { + t.Fatalf("expected zero-value GraphQL config, got %+v", cfg) + } +} diff --git a/graphql_test.go b/graphql_test.go index e201858..47c6dce 100644 --- a/graphql_test.go +++ b/graphql_test.go @@ -192,6 +192,72 @@ func TestWithGraphQL_Good_CustomPath(t *testing.T) { } } +func TestWithGraphQL_Good_NormalisesCustomPath(t *testing.T) { + gin.SetMode(gin.TestMode) + + e, err := api.New(api.WithGraphQL(newTestSchema(), api.WithGraphQLPath(" /gql/ "), api.WithPlayground())) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + srv := httptest.NewServer(e.Handler()) + defer srv.Close() + + body := `{"query":"{ name }"}` + resp, err := http.Post(srv.URL+"/gql", "application/json", strings.NewReader(body)) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200 at normalised /gql, got %d", resp.StatusCode) + } + + pgResp, err := http.Get(srv.URL + "/gql/playground") + if err != nil { + t.Fatalf("playground request failed: %v", err) + } + defer pgResp.Body.Close() + + if pgResp.StatusCode != http.StatusOK { + t.Fatalf("expected 200 at normalised /gql/playground, got %d", pgResp.StatusCode) + } +} + +func TestWithGraphQL_Good_DefaultPathWhenEmptyCustomPath(t *testing.T) { + gin.SetMode(gin.TestMode) + + e, err := api.New(api.WithGraphQL(newTestSchema(), api.WithGraphQLPath(""), api.WithPlayground())) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + srv := httptest.NewServer(e.Handler()) + defer srv.Close() + + body := `{"query":"{ name }"}` + resp, err := http.Post(srv.URL+"/graphql", "application/json", strings.NewReader(body)) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200 at default /graphql, got %d", resp.StatusCode) + } + + pgResp, err := http.Get(srv.URL + "/graphql/playground") + if err != nil { + t.Fatalf("playground request failed: %v", err) + } + defer pgResp.Body.Close() + + if pgResp.StatusCode != http.StatusOK { + t.Fatalf("expected 200 at default /graphql/playground, got %d", pgResp.StatusCode) + } +} + func TestWithGraphQL_Good_CombinesWithOtherMiddleware(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/group.go b/group.go index 46d0cf4..7ee8798 100644 --- a/group.go +++ b/group.go @@ -2,10 +2,18 @@ package api -import "github.com/gin-gonic/gin" +import ( + "iter" + + "github.com/gin-gonic/gin" +) // RouteGroup registers API routes onto a Gin router group. // Subsystems implement this interface to declare their endpoints. +// +// Example: +// +// var g api.RouteGroup = &myGroup{} type RouteGroup interface { // Name returns a human-readable identifier for the group. Name() string @@ -18,6 +26,10 @@ type RouteGroup interface { } // StreamGroup optionally declares WebSocket channels a subsystem publishes to. +// +// Example: +// +// var sg api.StreamGroup = &myStreamGroup{} type StreamGroup interface { // Channels returns the list of channel names this group streams on. Channels() []string @@ -26,19 +38,89 @@ type StreamGroup interface { // DescribableGroup extends RouteGroup with OpenAPI metadata. // RouteGroups that implement this will have their endpoints // included in the generated OpenAPI specification. +// +// Example: +// +// var dg api.DescribableGroup = &myDescribableGroup{} type DescribableGroup interface { RouteGroup // Describe returns endpoint descriptions for OpenAPI generation. Describe() []RouteDescription } +// DescribableGroupIter extends DescribableGroup with an iterator-based +// description source for callers that want to avoid slice allocation. +// +// Example: +// +// var dg api.DescribableGroupIter = &myDescribableGroup{} +type DescribableGroupIter interface { + DescribableGroup + // DescribeIter returns endpoint descriptions for OpenAPI generation. + DescribeIter() iter.Seq[RouteDescription] +} + // RouteDescription describes a single endpoint for OpenAPI generation. +// +// Example: +// +// rd := api.RouteDescription{ +// Method: "POST", +// Path: "/users", +// Summary: "Create a user", +// Description: "Creates a new user account.", +// Tags: []string{"users"}, +// StatusCode: 201, +// RequestBody: map[string]any{"type": "object"}, +// Response: map[string]any{"type": "object"}, +// } type RouteDescription struct { - Method string // HTTP method: GET, POST, PUT, DELETE, PATCH - Path string // Path relative to BasePath, e.g. "/generate" - Summary string // Short summary - Description string // Long description - Tags []string // OpenAPI tags for grouping - RequestBody map[string]any // JSON Schema for request body (nil for GET) - Response map[string]any // JSON Schema for success response data + Method string // HTTP method: GET, POST, PUT, DELETE, PATCH + Path string // Path relative to BasePath, e.g. "/generate" + Summary string // Short summary + Description string // Long description + Tags []string // OpenAPI tags for grouping + // Hidden omits the route from generated documentation. + Hidden bool + // Deprecated marks the operation as deprecated in OpenAPI. + Deprecated bool + // SunsetDate marks when a deprecated operation will be removed. + // Use YYYY-MM-DD or an RFC 7231 HTTP date string. + SunsetDate string + // Replacement points to the successor endpoint URL, when known. + Replacement string + // StatusCode is the documented 2xx success status code. + // Zero defaults to 200. + StatusCode int + // Security overrides the default bearerAuth requirement when non-nil. + // Use an empty, non-nil slice to mark the route as public. + Security []map[string][]string + Parameters []ParameterDescription + RequestBody map[string]any // JSON Schema for request body (nil for GET) + RequestExample any // Optional example payload for the request body. + Response map[string]any // JSON Schema for success response data + ResponseExample any // Optional example payload for the success response. + ResponseHeaders map[string]string +} + +// ParameterDescription describes an OpenAPI parameter for a route. +// +// Example: +// +// param := api.ParameterDescription{ +// Name: "id", +// In: "path", +// Description: "User identifier", +// Required: true, +// Schema: map[string]any{"type": "string"}, +// Example: "usr_123", +// } +type ParameterDescription struct { + Name string // Parameter name. + In string // Parameter location: path, query, header, or cookie. + Description string // Human-readable parameter description. + Required bool // Whether the parameter is required. + Deprecated bool // Whether the parameter is deprecated. + Schema map[string]any // JSON Schema for the parameter value. + Example any // Optional example value. } diff --git a/i18n.go b/i18n.go index a9b5974..03ddb3b 100644 --- a/i18n.go +++ b/i18n.go @@ -3,6 +3,9 @@ package api import ( + "slices" + "strings" + "github.com/gin-gonic/gin" "golang.org/x/text/language" ) @@ -13,7 +16,21 @@ const i18nContextKey = "i18n.locale" // i18nMessagesKey is the Gin context key for the message lookup map. const i18nMessagesKey = "i18n.messages" +// i18nCatalogKey is the Gin context key for the full locale->message catalog. +const i18nCatalogKey = "i18n.catalog" + +// i18nDefaultLocaleKey stores the configured default locale for fallback lookups. +const i18nDefaultLocaleKey = "i18n.default_locale" + // I18nConfig configures the internationalisation middleware. +// +// Example: +// +// cfg := api.I18nConfig{ +// DefaultLocale: "en", +// Supported: []string{"en", "fr"}, +// Messages: map[string]map[string]string{"fr": {"greeting": "Bonjour"}}, +// } type I18nConfig struct { // DefaultLocale is the fallback locale when the Accept-Language header // is absent or does not match any supported locale. Defaults to "en". @@ -30,11 +47,32 @@ type I18nConfig struct { Messages map[string]map[string]string } +// I18nConfig returns the configured locale and message catalogue settings for +// the engine. +// +// The result snapshots the Engine state at call time and clones slices/maps so +// callers can safely reuse or modify the returned value. +// +// Example: +// +// cfg := engine.I18nConfig() +func (e *Engine) I18nConfig() I18nConfig { + if e == nil { + return I18nConfig{} + } + + return cloneI18nConfig(e.i18nConfig) +} + // WithI18n adds Accept-Language header parsing and locale detection middleware. // The middleware uses golang.org/x/text/language for RFC 5646 language matching // with quality weighting support. The detected locale is stored in the Gin // context and can be retrieved by handlers via GetLocale(). // +// Example: +// +// api.New(api.WithI18n(api.I18nConfig{Supported: []string{"en", "fr"}})) +// // If messages are configured, handlers can look up localised strings via // GetMessage(). This is a lightweight bridge — the go-i18n grammar engine // can replace the message map later. @@ -57,14 +95,16 @@ func WithI18n(cfg ...I18nConfig) Option { tags = append(tags, tag) } } + snapshot := cloneI18nConfig(config) + e.i18nConfig = snapshot matcher := language.NewMatcher(tags) - e.middlewares = append(e.middlewares, i18nMiddleware(matcher, config)) + e.middlewares = append(e.middlewares, i18nMiddleware(matcher, snapshot)) } } // i18nMiddleware returns Gin middleware that parses Accept-Language, matches -// it against supported locales, and stores the result in the context. +// it against supported locales, and stores the resolved BCP 47 tag in the context. func i18nMiddleware(matcher language.Matcher, cfg I18nConfig) gin.HandlerFunc { return func(c *gin.Context) { accept := c.GetHeader("Accept-Language") @@ -75,19 +115,17 @@ func i18nMiddleware(matcher language.Matcher, cfg I18nConfig) gin.HandlerFunc { } else { tags, _, _ := language.ParseAcceptLanguage(accept) tag, _, _ := matcher.Match(tags...) - base, _ := tag.Base() - locale = base.String() + locale = tag.String() } c.Set(i18nContextKey, locale) + c.Set(i18nDefaultLocaleKey, cfg.DefaultLocale) // Attach the message map for this locale if messages are configured. if cfg.Messages != nil { + c.Set(i18nCatalogKey, cfg.Messages) if msgs, ok := cfg.Messages[locale]; ok { c.Set(i18nMessagesKey, msgs) - } else if msgs, ok := cfg.Messages[cfg.DefaultLocale]; ok { - // Fall back to default locale messages. - c.Set(i18nMessagesKey, msgs) } } @@ -97,6 +135,10 @@ func i18nMiddleware(matcher language.Matcher, cfg I18nConfig) gin.HandlerFunc { // GetLocale returns the detected locale for the current request. // Returns "en" if the i18n middleware was not applied. +// +// Example: +// +// locale := api.GetLocale(c) func GetLocale(c *gin.Context) string { if v, ok := c.Get(i18nContextKey); ok { if s, ok := v.(string); ok { @@ -109,6 +151,10 @@ func GetLocale(c *gin.Context) string { // GetMessage looks up a localised message by key for the current request. // Returns the message string and true if found, or empty string and false // if the key does not exist or the i18n middleware was not applied. +// +// Example: +// +// msg, ok := api.GetMessage(c, "greeting") func GetMessage(c *gin.Context, key string) (string, bool) { if v, ok := c.Get(i18nMessagesKey); ok { if msgs, ok := v.(map[string]string); ok { @@ -117,5 +163,84 @@ func GetMessage(c *gin.Context, key string) (string, bool) { } } } + + catalog, _ := c.Get(i18nCatalogKey) + msgsByLocale, _ := catalog.(map[string]map[string]string) + if len(msgsByLocale) == 0 { + return "", false + } + + locales := localeFallbacks(GetLocale(c)) + if defaultLocale, ok := c.Get(i18nDefaultLocaleKey); ok { + if fallback, ok := defaultLocale.(string); ok && fallback != "" { + locales = append(locales, localeFallbacks(fallback)...) + } + } + + seen := make(map[string]struct{}, len(locales)) + for _, locale := range locales { + if locale == "" { + continue + } + if _, ok := seen[locale]; ok { + continue + } + seen[locale] = struct{}{} + if msgs, ok := msgsByLocale[locale]; ok { + if msg, ok := msgs[key]; ok { + return msg, true + } + } + } + return "", false } + +// localeFallbacks returns the locale and its parent tags in order from +// most specific to least specific. For example, "fr-CA" yields +// ["fr-CA", "fr"] and "zh-Hant-TW" yields ["zh-Hant-TW", "zh-Hant", "zh"]. +func localeFallbacks(locale string) []string { + locale = strings.TrimSpace(strings.ReplaceAll(locale, "_", "-")) + if locale == "" { + return nil + } + + parts := strings.Split(locale, "-") + if len(parts) == 0 { + return []string{locale} + } + + fallbacks := make([]string, 0, len(parts)) + for i := len(parts); i >= 1; i-- { + fallbacks = append(fallbacks, strings.Join(parts[:i], "-")) + } + + return fallbacks +} + +func cloneI18nConfig(cfg I18nConfig) I18nConfig { + out := cfg + out.Supported = slices.Clone(cfg.Supported) + out.Messages = cloneI18nMessages(cfg.Messages) + return out +} + +func cloneI18nMessages(messages map[string]map[string]string) map[string]map[string]string { + if len(messages) == 0 { + return nil + } + + out := make(map[string]map[string]string, len(messages)) + for locale, msgs := range messages { + if len(msgs) == 0 { + out[locale] = nil + continue + } + cloned := make(map[string]string, len(msgs)) + for key, value := range msgs { + cloned[key] = value + } + out[locale] = cloned + } + return out +} diff --git a/i18n_test.go b/i18n_test.go index 66189e7..b56e5cf 100644 --- a/i18n_test.go +++ b/i18n_test.go @@ -6,6 +6,7 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "slices" "testing" "github.com/gin-gonic/gin" @@ -133,6 +134,33 @@ func TestWithI18n_Good_QualityWeighting(t *testing.T) { } } +func TestWithI18n_Good_PreservesMatchedLocaleTag(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithI18n(api.I18nConfig{ + DefaultLocale: "en", + Supported: []string{"en", "fr", "fr-CA"}, + })) + e.Register(&i18nTestGroup{}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/i18n/locale", nil) + req.Header.Set("Accept-Language", "fr-CA, fr;q=0.8") + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + + var resp i18nLocaleResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.Data["locale"] != "fr-CA" { + t.Fatalf("expected locale=%q, got %q", "fr-CA", resp.Data["locale"]) + } +} + func TestWithI18n_Good_CombinesWithOtherMiddleware(t *testing.T) { gin.SetMode(gin.TestMode) e, _ := api.New( @@ -224,3 +252,122 @@ func TestWithI18n_Good_LooksUpMessage(t *testing.T) { t.Fatalf("expected message=%q, got %q", "Hello", respEn.Data.Message) } } + +func TestWithI18n_Good_FallsBackToParentLocaleMessage(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithI18n(api.I18nConfig{ + DefaultLocale: "en", + Supported: []string{"en", "fr", "fr-CA"}, + Messages: map[string]map[string]string{ + "en": {"greeting": "Hello"}, + "fr": {"greeting": "Bonjour"}, + }, + })) + e.Register(&i18nTestGroup{}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/i18n/greeting", nil) + req.Header.Set("Accept-Language", "fr-CA") + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + + var resp i18nMessageResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.Data.Locale != "fr-CA" { + t.Fatalf("expected locale=%q, got %q", "fr-CA", resp.Data.Locale) + } + if resp.Data.Message != "Bonjour" { + t.Fatalf("expected fallback message=%q, got %q", "Bonjour", resp.Data.Message) + } + if !resp.Data.Found { + t.Fatal("expected found=true") + } +} + +func TestEngine_I18nConfig_Good_SnapshotsCurrentSettings(t *testing.T) { + e, _ := api.New(api.WithI18n(api.I18nConfig{ + DefaultLocale: "en", + Supported: []string{"en", "fr"}, + Messages: map[string]map[string]string{ + "en": {"greeting": "Hello"}, + "fr": {"greeting": "Bonjour"}, + }, + })) + + snap := e.I18nConfig() + if snap.DefaultLocale != "en" { + t.Fatalf("expected default locale en, got %q", snap.DefaultLocale) + } + if !slices.Equal(snap.Supported, []string{"en", "fr"}) { + t.Fatalf("expected supported locales [en fr], got %v", snap.Supported) + } + if snap.Messages["fr"]["greeting"] != "Bonjour" { + t.Fatalf("expected cloned French greeting, got %q", snap.Messages["fr"]["greeting"]) + } +} + +func TestEngine_I18nConfig_Good_ClonesMutableInputs(t *testing.T) { + supported := []string{"en", "fr"} + messages := map[string]map[string]string{ + "en": {"greeting": "Hello"}, + "fr": {"greeting": "Bonjour"}, + } + + e, _ := api.New(api.WithI18n(api.I18nConfig{ + DefaultLocale: "en", + Supported: supported, + Messages: messages, + })) + + supported[0] = "de" + messages["fr"]["greeting"] = "Salut" + + snap := e.I18nConfig() + if !slices.Equal(snap.Supported, []string{"en", "fr"}) { + t.Fatalf("expected engine supported locales to be cloned, got %v", snap.Supported) + } + if snap.Messages["fr"]["greeting"] != "Bonjour" { + t.Fatalf("expected engine message catalogue to be cloned, got %q", snap.Messages["fr"]["greeting"]) + } +} + +func TestWithI18n_Good_SnapshotsMutableInputs(t *testing.T) { + gin.SetMode(gin.TestMode) + messages := map[string]map[string]string{ + "en": {"greeting": "Hello"}, + "fr": {"greeting": "Bonjour"}, + } + + e, _ := api.New(api.WithI18n(api.I18nConfig{ + DefaultLocale: "en", + Supported: []string{"en", "fr"}, + Messages: messages, + })) + e.Register(&i18nTestGroup{}) + + messages["fr"]["greeting"] = "Salut" + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/i18n/greeting", nil) + req.Header.Set("Accept-Language", "fr") + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + + var resp i18nMessageResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.Data.Message != "Bonjour" { + t.Fatalf("expected cloned greeting %q, got %q", "Bonjour", resp.Data.Message) + } +} diff --git a/middleware.go b/middleware.go index 55fe8ae..35e9120 100644 --- a/middleware.go +++ b/middleware.go @@ -5,20 +5,44 @@ package api import ( "crypto/rand" "encoding/hex" + "fmt" "net/http" + "runtime/debug" "strings" + "time" "github.com/gin-gonic/gin" ) +// requestIDContextKey is the Gin context key used by requestIDMiddleware. +const requestIDContextKey = "request_id" + +// requestStartContextKey stores when the request began so handlers can +// calculate elapsed duration for response metadata. +const requestStartContextKey = "request_start" + +// recoveryMiddleware converts panics into a standard JSON error envelope. +// This keeps internal failures consistent with the rest of the framework +// and avoids Gin's default plain-text 500 response. +func recoveryMiddleware() gin.HandlerFunc { + return gin.CustomRecovery(func(c *gin.Context, recovered any) { + fmt.Fprintf(gin.DefaultErrorWriter, "[Recovery] panic recovered: %v\n", recovered) + debug.PrintStack() + c.AbortWithStatusJSON(http.StatusInternalServerError, Fail( + "internal_server_error", + "Internal server error", + )) + }) +} + // bearerAuthMiddleware validates the Authorization: Bearer header. // Requests to paths in the skip list are allowed through without authentication. // Returns 401 with Fail("unauthorised", ...) on missing or invalid tokens. -func bearerAuthMiddleware(token string, skip []string) gin.HandlerFunc { +func bearerAuthMiddleware(token string, skip func() []string) gin.HandlerFunc { return func(c *gin.Context) { // Check whether the request path should bypass authentication. - for _, path := range skip { - if strings.HasPrefix(c.Request.URL.Path, path) { + for _, path := range skip() { + if isPublicPath(c.Request.URL.Path, path) { c.Next() return } @@ -40,11 +64,37 @@ func bearerAuthMiddleware(token string, skip []string) gin.HandlerFunc { } } +// isPublicPath reports whether requestPath should bypass auth for publicPath. +// It matches the exact path and any nested subpath, but not sibling prefixes +// such as /swaggerx when the public path is /swagger. +func isPublicPath(requestPath, publicPath string) bool { + if publicPath == "" { + return false + } + + normalized := strings.TrimRight(publicPath, "/") + if normalized == "" { + normalized = "/" + } + + if requestPath == normalized { + return true + } + + if normalized == "/" { + return true + } + + return strings.HasPrefix(requestPath, normalized+"/") +} + // requestIDMiddleware ensures every response carries an X-Request-ID header. // If the client sends one, it is preserved; otherwise a random 16-byte hex // string is generated. The ID is also stored in the Gin context as "request_id". func requestIDMiddleware() gin.HandlerFunc { return func(c *gin.Context) { + c.Set(requestStartContextKey, time.Now()) + id := c.GetHeader("X-Request-ID") if id == "" { b := make([]byte, 16) @@ -52,8 +102,63 @@ func requestIDMiddleware() gin.HandlerFunc { id = hex.EncodeToString(b) } - c.Set("request_id", id) + c.Set(requestIDContextKey, id) c.Header("X-Request-ID", id) c.Next() } } + +// GetRequestID returns the request ID assigned by requestIDMiddleware. +// Returns an empty string when the middleware was not applied. +// +// Example: +// +// id := api.GetRequestID(c) +func GetRequestID(c *gin.Context) string { + if v, ok := c.Get(requestIDContextKey); ok { + if s, ok := v.(string); ok { + return s + } + } + return "" +} + +// GetRequestDuration returns the elapsed time since requestIDMiddleware started +// handling the request. Returns 0 when the middleware was not applied. +// +// Example: +// +// d := api.GetRequestDuration(c) +func GetRequestDuration(c *gin.Context) time.Duration { + if v, ok := c.Get(requestStartContextKey); ok { + if started, ok := v.(time.Time); ok && !started.IsZero() { + return time.Since(started) + } + } + return 0 +} + +// GetRequestMeta returns request metadata collected by requestIDMiddleware. +// The returned meta includes the request ID and elapsed duration when +// available. It returns nil when neither value is available. +// +// Example: +// +// meta := api.GetRequestMeta(c) +func GetRequestMeta(c *gin.Context) *Meta { + meta := &Meta{} + + if id := GetRequestID(c); id != "" { + meta.RequestID = id + } + + if duration := GetRequestDuration(c); duration > 0 { + meta.Duration = duration.String() + } + + if meta.RequestID == "" && meta.Duration == "" { + return nil + } + + return meta +} diff --git a/middleware_test.go b/middleware_test.go index a44da53..58d9c43 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/gin-gonic/gin" @@ -26,6 +27,75 @@ func (m *mwTestGroup) RegisterRoutes(rg *gin.RouterGroup) { }) } +type swaggerLikeGroup struct{} + +func (g *swaggerLikeGroup) Name() string { return "swagger-like" } +func (g *swaggerLikeGroup) BasePath() string { return "/swaggerx" } +func (g *swaggerLikeGroup) RegisterRoutes(rg *gin.RouterGroup) { + rg.GET("/secret", func(c *gin.Context) { + c.JSON(http.StatusOK, api.OK("classified")) + }) +} + +type requestIDTestGroup struct { + gotID *string +} + +func (g requestIDTestGroup) Name() string { return "request-id" } +func (g requestIDTestGroup) BasePath() string { return "/v1" } +func (g requestIDTestGroup) RegisterRoutes(rg *gin.RouterGroup) { + rg.GET("/secret", func(c *gin.Context) { + *g.gotID = api.GetRequestID(c) + c.JSON(http.StatusOK, api.OK("classified")) + }) +} + +type requestMetaTestGroup struct{} + +func (g requestMetaTestGroup) Name() string { return "request-meta" } +func (g requestMetaTestGroup) BasePath() string { return "/v1" } +func (g requestMetaTestGroup) RegisterRoutes(rg *gin.RouterGroup) { + rg.GET("/meta", func(c *gin.Context) { + time.Sleep(2 * time.Millisecond) + resp := api.AttachRequestMeta(c, api.Paginated("classified", 1, 25, 100)) + c.JSON(http.StatusOK, resp) + }) +} + +type autoResponseMetaTestGroup struct{} + +func (g autoResponseMetaTestGroup) Name() string { return "auto-response-meta" } +func (g autoResponseMetaTestGroup) BasePath() string { return "/v1" } +func (g autoResponseMetaTestGroup) RegisterRoutes(rg *gin.RouterGroup) { + rg.GET("/meta", func(c *gin.Context) { + time.Sleep(2 * time.Millisecond) + c.JSON(http.StatusOK, api.Paginated("classified", 1, 25, 100)) + }) +} + +type autoErrorResponseMetaTestGroup struct{} + +func (g autoErrorResponseMetaTestGroup) Name() string { return "auto-error-response-meta" } +func (g autoErrorResponseMetaTestGroup) BasePath() string { return "/v1" } +func (g autoErrorResponseMetaTestGroup) RegisterRoutes(rg *gin.RouterGroup) { + rg.GET("/error", func(c *gin.Context) { + time.Sleep(2 * time.Millisecond) + c.JSON(http.StatusBadRequest, api.Fail("bad_request", "request failed")) + }) +} + +type plusJSONResponseMetaTestGroup struct{} + +func (g plusJSONResponseMetaTestGroup) Name() string { return "plus-json-response-meta" } +func (g plusJSONResponseMetaTestGroup) BasePath() string { return "/v1" } +func (g plusJSONResponseMetaTestGroup) RegisterRoutes(rg *gin.RouterGroup) { + rg.GET("/plus-json", func(c *gin.Context) { + c.Header("Content-Type", "application/problem+json") + c.Status(http.StatusOK) + _, _ = c.Writer.Write([]byte(`{"success":true,"data":"ok"}`)) + }) +} + // ── Bearer auth ───────────────────────────────────────────────────────── func TestBearerAuth_Bad_MissingToken(t *testing.T) { @@ -114,6 +184,21 @@ func TestBearerAuth_Good_HealthBypassesAuth(t *testing.T) { } } +func TestBearerAuth_Bad_SimilarPrefixDoesNotBypassAuth(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithBearerAuth("s3cret")) + e.Register(&swaggerLikeGroup{}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/swaggerx/secret", nil) + h.ServeHTTP(w, req) + + if w.Code != http.StatusUnauthorized { + t.Fatalf("expected 401 for /swaggerx/secret, got %d", w.Code) + } +} + // ── Request ID ────────────────────────────────────────────────────────── func TestRequestID_Good_GeneratedWhenMissing(t *testing.T) { @@ -151,6 +236,176 @@ func TestRequestID_Good_PreservesClientID(t *testing.T) { } } +func TestRequestID_Good_ContextAccessor(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithRequestID()) + + var gotID string + e.Register(requestIDTestGroup{gotID: &gotID}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/v1/secret", nil) + req.Header.Set("X-Request-ID", "client-id-xyz") + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + + if gotID == "" { + t.Fatal("expected GetRequestID to return the request ID inside the handler") + } + if gotID != "client-id-xyz" { + t.Fatalf("expected GetRequestID=%q, got %q", "client-id-xyz", gotID) + } +} + +func TestRequestID_Good_RequestMetaHelper(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithRequestID()) + e.Register(requestMetaTestGroup{}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/v1/meta", nil) + req.Header.Set("X-Request-ID", "client-id-meta") + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + + var resp api.Response[string] + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.Meta == nil { + t.Fatal("expected Meta to be present") + } + if resp.Meta.RequestID != "client-id-meta" { + t.Fatalf("expected request_id=%q, got %q", "client-id-meta", resp.Meta.RequestID) + } + if resp.Meta.Duration == "" { + t.Fatal("expected duration to be populated") + } + if resp.Meta.Page != 1 || resp.Meta.PerPage != 25 || resp.Meta.Total != 100 { + t.Fatalf("expected pagination metadata to be preserved, got %+v", resp.Meta) + } +} + +func TestResponseMeta_Good_AttachesMetaAutomatically(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New( + api.WithRequestID(), + api.WithResponseMeta(), + ) + e.Register(autoResponseMetaTestGroup{}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/v1/meta", nil) + req.Header.Set("X-Request-ID", "client-id-auto-meta") + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + + var resp api.Response[string] + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.Meta == nil { + t.Fatal("expected Meta to be present") + } + if resp.Meta.RequestID != "client-id-auto-meta" { + t.Fatalf("expected request_id=%q, got %q", "client-id-auto-meta", resp.Meta.RequestID) + } + if resp.Meta.Duration == "" { + t.Fatal("expected duration to be populated") + } + if resp.Meta.Page != 1 || resp.Meta.PerPage != 25 || resp.Meta.Total != 100 { + t.Fatalf("expected pagination metadata to be preserved, got %+v", resp.Meta) + } + if got := w.Header().Get("X-Request-ID"); got != "client-id-auto-meta" { + t.Fatalf("expected response header X-Request-ID=%q, got %q", "client-id-auto-meta", got) + } +} + +func TestResponseMeta_Good_AttachesMetaToErrorResponses(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New( + api.WithRequestID(), + api.WithResponseMeta(), + ) + e.Register(autoErrorResponseMetaTestGroup{}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/v1/error", nil) + req.Header.Set("X-Request-ID", "client-id-auto-error-meta") + h.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", w.Code) + } + + var resp api.Response[string] + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.Meta == nil { + t.Fatal("expected Meta to be present") + } + if resp.Meta.RequestID != "client-id-auto-error-meta" { + t.Fatalf("expected request_id=%q, got %q", "client-id-auto-error-meta", resp.Meta.RequestID) + } + if resp.Meta.Duration == "" { + t.Fatal("expected duration to be populated") + } + if resp.Error == nil || resp.Error.Code != "bad_request" { + t.Fatalf("expected bad_request error, got %+v", resp.Error) + } +} + +func TestResponseMeta_Good_AttachesMetaToPlusJSONContentType(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New( + api.WithRequestID(), + api.WithResponseMeta(), + ) + e.Register(plusJSONResponseMetaTestGroup{}) + + h := e.Handler() + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/v1/plus-json", nil) + req.Header.Set("X-Request-ID", "client-id-plus-json-meta") + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + + if got := w.Header().Get("Content-Type"); got != "application/problem+json" { + t.Fatalf("expected Content-Type to be preserved, got %q", got) + } + + var resp api.Response[string] + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.Meta == nil { + t.Fatal("expected Meta to be present") + } + if resp.Meta.RequestID != "client-id-plus-json-meta" { + t.Fatalf("expected request_id=%q, got %q", "client-id-plus-json-meta", resp.Meta.RequestID) + } + if resp.Meta.Duration == "" { + t.Fatal("expected duration to be populated") + } +} + // ── CORS ──────────────────────────────────────────────────────────────── func TestCORS_Good_PreflightAllOrigins(t *testing.T) { diff --git a/modernization_test.go b/modernization_test.go index 21d08a2..b6e3a7d 100644 --- a/modernization_test.go +++ b/modernization_test.go @@ -5,6 +5,7 @@ package api_test import ( "slices" "testing" + "time" api "dappco.re/go/core/api" ) @@ -27,6 +28,28 @@ func TestEngine_GroupsIter(t *testing.T) { } } +func TestEngine_GroupsIter_Good_SnapshotsCurrentGroups(t *testing.T) { + e, _ := api.New() + g1 := &healthGroup{} + g2 := &stubGroup{} + e.Register(g1) + + iter := e.GroupsIter() + e.Register(g2) + + var groups []api.RouteGroup + for g := range iter { + groups = append(groups, g) + } + + if len(groups) != 1 { + t.Fatalf("expected iterator snapshot to contain 1 group, got %d", len(groups)) + } + if groups[0].Name() != "health-extra" { + t.Fatalf("expected snapshot to preserve original group, got %q", groups[0].Name()) + } +} + type streamGroupStub struct { healthGroup channels []string @@ -52,6 +75,207 @@ func TestEngine_ChannelsIter(t *testing.T) { } } +func TestEngine_ChannelsIter_Good_SnapshotsCurrentChannels(t *testing.T) { + e, _ := api.New() + g1 := &streamGroupStub{channels: []string{"ch1", "ch2"}} + g2 := &streamGroupStub{channels: []string{"ch3"}} + e.Register(g1) + + iter := e.ChannelsIter() + e.Register(g2) + + var channels []string + for ch := range iter { + channels = append(channels, ch) + } + + expected := []string{"ch1", "ch2"} + if !slices.Equal(channels, expected) { + t.Fatalf("expected snapshot channels %v, got %v", expected, channels) + } +} + +func TestEngine_CacheConfig_Good_SnapshotsCurrentSettings(t *testing.T) { + e, _ := api.New(api.WithCacheLimits(5*time.Minute, 10, 1024)) + + cfg := e.CacheConfig() + + if !cfg.Enabled { + t.Fatal("expected cache config to be enabled") + } + if cfg.TTL != 5*time.Minute { + t.Fatalf("expected TTL %v, got %v", 5*time.Minute, cfg.TTL) + } + if cfg.MaxEntries != 10 { + t.Fatalf("expected MaxEntries 10, got %d", cfg.MaxEntries) + } + if cfg.MaxBytes != 1024 { + t.Fatalf("expected MaxBytes 1024, got %d", cfg.MaxBytes) + } +} + +func TestEngine_RuntimeConfig_Good_SnapshotsCurrentSettings(t *testing.T) { + broker := api.NewSSEBroker() + e, err := api.New( + api.WithSwagger("Runtime API", "Runtime snapshot", "1.2.3"), + api.WithSwaggerPath("/docs"), + api.WithCacheLimits(5*time.Minute, 10, 1024), + api.WithGraphQL(newTestSchema(), api.WithPlayground()), + api.WithI18n(api.I18nConfig{ + DefaultLocale: "en-GB", + Supported: []string{"en-GB", "fr"}, + }), + api.WithWSPath("/socket"), + api.WithSSE(broker), + api.WithSSEPath("/events"), + api.WithAuthentik(api.AuthentikConfig{ + Issuer: "https://auth.example.com", + ClientID: "runtime-client", + TrustedProxy: true, + PublicPaths: []string{"/public", "/docs"}, + }), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + cfg := e.RuntimeConfig() + + if !cfg.Swagger.Enabled { + t.Fatal("expected swagger snapshot to be enabled") + } + if cfg.Swagger.Path != "/docs" { + t.Fatalf("expected swagger path /docs, got %q", cfg.Swagger.Path) + } + if cfg.Transport.SwaggerPath != "/docs" { + t.Fatalf("expected transport swagger path /docs, got %q", cfg.Transport.SwaggerPath) + } + if cfg.Transport.GraphQLPlaygroundPath != "/graphql/playground" { + t.Fatalf("expected transport graphql playground path /graphql/playground, got %q", cfg.Transport.GraphQLPlaygroundPath) + } + if !cfg.Cache.Enabled || cfg.Cache.TTL != 5*time.Minute { + t.Fatalf("expected cache snapshot to be populated, got %+v", cfg.Cache) + } + if !cfg.GraphQL.Enabled { + t.Fatal("expected GraphQL snapshot to be enabled") + } + if cfg.GraphQL.Path != "/graphql" { + t.Fatalf("expected GraphQL path /graphql, got %q", cfg.GraphQL.Path) + } + if !cfg.GraphQL.Playground { + t.Fatal("expected GraphQL playground snapshot to be enabled") + } + if cfg.GraphQL.PlaygroundPath != "/graphql/playground" { + t.Fatalf("expected GraphQL playground path /graphql/playground, got %q", cfg.GraphQL.PlaygroundPath) + } + if cfg.I18n.DefaultLocale != "en-GB" { + t.Fatalf("expected default locale en-GB, got %q", cfg.I18n.DefaultLocale) + } + if !slices.Equal(cfg.I18n.Supported, []string{"en-GB", "fr"}) { + t.Fatalf("expected supported locales [en-GB fr], got %v", cfg.I18n.Supported) + } + if cfg.Authentik.Issuer != "https://auth.example.com" { + t.Fatalf("expected Authentik issuer https://auth.example.com, got %q", cfg.Authentik.Issuer) + } + if cfg.Authentik.ClientID != "runtime-client" { + t.Fatalf("expected Authentik client ID runtime-client, got %q", cfg.Authentik.ClientID) + } + if !cfg.Authentik.TrustedProxy { + t.Fatal("expected Authentik trusted proxy to be enabled") + } + if !slices.Equal(cfg.Authentik.PublicPaths, []string{"/public", "/docs"}) { + t.Fatalf("expected Authentik public paths [/public /docs], got %v", cfg.Authentik.PublicPaths) + } +} + +func TestEngine_RuntimeConfig_Good_EmptyOnNilEngine(t *testing.T) { + var e *api.Engine + + cfg := e.RuntimeConfig() + if cfg.Swagger.Enabled || cfg.Transport.SwaggerEnabled || cfg.GraphQL.Enabled || cfg.Cache.Enabled || cfg.I18n.DefaultLocale != "" || cfg.Authentik.Issuer != "" { + t.Fatalf("expected zero-value runtime config, got %+v", cfg) + } +} + +func TestEngine_AuthentikConfig_Good_SnapshotsCurrentSettings(t *testing.T) { + e, _ := api.New(api.WithAuthentik(api.AuthentikConfig{ + Issuer: "https://auth.example.com", + ClientID: "client", + TrustedProxy: true, + PublicPaths: []string{"/public", "/docs"}, + })) + + cfg := e.AuthentikConfig() + if cfg.Issuer != "https://auth.example.com" { + t.Fatalf("expected issuer https://auth.example.com, got %q", cfg.Issuer) + } + if cfg.ClientID != "client" { + t.Fatalf("expected client ID client, got %q", cfg.ClientID) + } + if !cfg.TrustedProxy { + t.Fatal("expected trusted proxy to be enabled") + } + if !slices.Equal(cfg.PublicPaths, []string{"/public", "/docs"}) { + t.Fatalf("expected public paths [/public /docs], got %v", cfg.PublicPaths) + } +} + +func TestEngine_AuthentikConfig_Good_ClonesPublicPaths(t *testing.T) { + publicPaths := []string{"/public", "/docs"} + e, _ := api.New(api.WithAuthentik(api.AuthentikConfig{ + Issuer: "https://auth.example.com", + PublicPaths: publicPaths, + })) + + cfg := e.AuthentikConfig() + publicPaths[0] = "/mutated" + + if cfg.PublicPaths[0] != "/public" { + t.Fatalf("expected snapshot to preserve original public paths, got %v", cfg.PublicPaths) + } +} + +func TestEngine_AuthentikConfig_Good_NormalisesPublicPaths(t *testing.T) { + e, _ := api.New(api.WithAuthentik(api.AuthentikConfig{ + PublicPaths: []string{" /public/ ", "docs", "/public"}, + })) + + cfg := e.AuthentikConfig() + expected := []string{"/public", "/docs"} + if !slices.Equal(cfg.PublicPaths, expected) { + t.Fatalf("expected normalised public paths %v, got %v", expected, cfg.PublicPaths) + } +} + +func TestEngine_AuthentikConfig_Good_BlankPublicPathsRemainNil(t *testing.T) { + e, _ := api.New(api.WithAuthentik(api.AuthentikConfig{ + PublicPaths: []string{" ", "\t", ""}, + })) + + cfg := e.AuthentikConfig() + if cfg.PublicPaths != nil { + t.Fatalf("expected blank public paths to collapse to nil, got %v", cfg.PublicPaths) + } +} + +func TestEngine_Register_Good_IgnoresNilGroups(t *testing.T) { + e, _ := api.New() + + var nilGroup *healthGroup + e.Register(nilGroup) + + g1 := &healthGroup{} + e.Register(g1) + + groups := e.Groups() + if len(groups) != 1 { + t.Fatalf("expected 1 registered group, got %d", len(groups)) + } + if groups[0].Name() != "health-extra" { + t.Fatalf("expected the original group to be preserved, got %q", groups[0].Name()) + } +} + func TestToolBridge_Iterators(t *testing.T) { b := api.NewToolBridge("/tools") desc := api.ToolDescriptor{Name: "test", Group: "g1"} @@ -76,6 +300,33 @@ func TestToolBridge_Iterators(t *testing.T) { } } +func TestToolBridge_Iterators_Good_SnapshotCurrentTools(t *testing.T) { + b := api.NewToolBridge("/tools") + b.Add(api.ToolDescriptor{Name: "first", Group: "g1"}, nil) + + toolsIter := b.ToolsIter() + descsIter := b.DescribeIter() + + b.Add(api.ToolDescriptor{Name: "second", Group: "g2"}, nil) + + var tools []api.ToolDescriptor + for tool := range toolsIter { + tools = append(tools, tool) + } + + var descs []api.RouteDescription + for desc := range descsIter { + descs = append(descs, desc) + } + + if len(tools) != 1 || tools[0].Name != "first" { + t.Fatalf("expected ToolsIter snapshot to contain the original tool, got %v", tools) + } + if len(descs) != 1 || descs[0].Path != "/first" { + t.Fatalf("expected DescribeIter snapshot to contain the original tool, got %v", descs) + } +} + func TestCodegen_SupportedLanguagesIter(t *testing.T) { var langs []string for l := range api.SupportedLanguagesIter() { diff --git a/openapi.go b/openapi.go index b98d8d1..ab2b27b 100644 --- a/openapi.go +++ b/openapi.go @@ -4,34 +4,236 @@ package api import ( "encoding/json" + "iter" + "net/http" + "sort" + "strconv" "strings" + "time" + "unicode" + + "slices" ) // SpecBuilder constructs an OpenAPI 3.1 specification from registered RouteGroups. +// Title, Summary, Description, Version, and optional contact/licence/terms metadata populate the +// OpenAPI info block. Top-level external documentation metadata is also supported, along with +// additive extension fields that describe runtime transport, cache, i18n, and Authentik settings. +// +// Example: +// +// builder := &api.SpecBuilder{Title: "Service", Version: "1.0.0"} +// spec, err := builder.Build(engine.Groups()) type SpecBuilder struct { - Title string - Description string - Version string + Title string + Summary string + Description string + Version string + SwaggerEnabled bool + SwaggerPath string + GraphQLEnabled bool + GraphQLPath string + GraphQLPlayground bool + GraphQLPlaygroundPath string + WSPath string + WSEnabled bool + SSEPath string + SSEEnabled bool + TermsOfService string + ContactName string + ContactURL string + ContactEmail string + Servers []string + LicenseName string + LicenseURL string + SecuritySchemes map[string]any + ExternalDocsDescription string + ExternalDocsURL string + PprofEnabled bool + ExpvarEnabled bool + CacheEnabled bool + CacheTTL string + CacheMaxEntries int + CacheMaxBytes int + I18nDefaultLocale string + I18nSupportedLocales []string + AuthentikIssuer string + AuthentikClientID string + AuthentikTrustedProxy bool + AuthentikPublicPaths []string +} + +type preparedRouteGroup struct { + name string + basePath string + descs []RouteDescription } +const openAPIDialect = "https://spec.openapis.org/oas/3.1/dialect/base" + // Build generates the complete OpenAPI 3.1 JSON spec. // Groups implementing DescribableGroup contribute endpoint documentation. // Other groups are listed as tags only. +// +// Example: +// +// data, err := (&api.SpecBuilder{Title: "Service", Version: "1.0.0"}).Build(engine.Groups()) func (sb *SpecBuilder) Build(groups []RouteGroup) ([]byte, error) { + if sb == nil { + sb = &SpecBuilder{} + } + sb = sb.snapshot() + + prepared := prepareRouteGroups(groups) + + info := map[string]any{ + "title": sb.Title, + "description": sb.Description, + "version": sb.Version, + } + if sb.Summary != "" { + info["summary"] = sb.Summary + } + spec := map[string]any{ - "openapi": "3.1.0", - "info": map[string]any{ - "title": sb.Title, - "description": sb.Description, - "version": sb.Version, - }, - "paths": sb.buildPaths(groups), - "tags": sb.buildTags(groups), + "openapi": "3.1.0", + "jsonSchemaDialect": openAPIDialect, + "info": info, + "paths": sb.buildPaths(prepared), + "tags": sb.buildTags(prepared), + } + + if sb.LicenseName != "" { + license := map[string]any{ + "name": sb.LicenseName, + } + if sb.LicenseURL != "" { + license["url"] = sb.LicenseURL + } + spec["info"].(map[string]any)["license"] = license + } + if swaggerPath := sb.effectiveSwaggerPath(); swaggerPath != "" { + spec["x-swagger-ui-path"] = normaliseSwaggerPath(swaggerPath) + } + if sb.SwaggerEnabled { + spec["x-swagger-enabled"] = true + } + if sb.GraphQLEnabled { + spec["x-graphql-enabled"] = true + } + if graphqlPath := sb.effectiveGraphQLPath(); graphqlPath != "" { + spec["x-graphql-path"] = normaliseOpenAPIPath(graphqlPath) + if sb.GraphQLPlayground { + spec["x-graphql-playground"] = true + } + } + if sb.GraphQLPlayground { + if playgroundPath := sb.effectiveGraphQLPlaygroundPath(); playgroundPath != "" { + spec["x-graphql-playground-path"] = normaliseOpenAPIPath(playgroundPath) + } + } + if wsPath := sb.effectiveWSPath(); wsPath != "" { + spec["x-ws-path"] = normaliseOpenAPIPath(wsPath) + } + if sb.WSEnabled { + spec["x-ws-enabled"] = true + } + if ssePath := sb.effectiveSSEPath(); ssePath != "" { + spec["x-sse-path"] = normaliseOpenAPIPath(ssePath) + } + if sb.SSEEnabled { + spec["x-sse-enabled"] = true + } + if sb.PprofEnabled { + spec["x-pprof-enabled"] = true + } + if sb.ExpvarEnabled { + spec["x-expvar-enabled"] = true + } + if sb.CacheEnabled { + spec["x-cache-enabled"] = true + } + if ttl := sb.effectiveCacheTTL(); ttl != "" { + spec["x-cache-ttl"] = ttl + } + if sb.CacheMaxEntries > 0 { + spec["x-cache-max-entries"] = sb.CacheMaxEntries + } + if sb.CacheMaxBytes > 0 { + spec["x-cache-max-bytes"] = sb.CacheMaxBytes + } + if locale := strings.TrimSpace(sb.I18nDefaultLocale); locale != "" { + spec["x-i18n-default-locale"] = locale + } + if len(sb.I18nSupportedLocales) > 0 { + spec["x-i18n-supported-locales"] = slices.Clone(sb.I18nSupportedLocales) + } + if issuer := strings.TrimSpace(sb.AuthentikIssuer); issuer != "" { + spec["x-authentik-issuer"] = issuer + } + if clientID := strings.TrimSpace(sb.AuthentikClientID); clientID != "" { + spec["x-authentik-client-id"] = clientID + } + if sb.AuthentikTrustedProxy { + spec["x-authentik-trusted-proxy"] = true + } + if paths := sb.effectiveAuthentikPublicPaths(); len(paths) > 0 { + spec["x-authentik-public-paths"] = paths + } + + if sb.TermsOfService != "" { + spec["info"].(map[string]any)["termsOfService"] = sb.TermsOfService + } + + if sb.ContactName != "" || sb.ContactURL != "" || sb.ContactEmail != "" { + contact := map[string]any{} + if sb.ContactName != "" { + contact["name"] = sb.ContactName + } + if sb.ContactURL != "" { + contact["url"] = sb.ContactURL + } + if sb.ContactEmail != "" { + contact["email"] = sb.ContactEmail + } + spec["info"].(map[string]any)["contact"] = contact + } + + if servers := normaliseServers(sb.Servers); len(servers) > 0 { + out := make([]map[string]any, 0, len(servers)) + for _, server := range servers { + out = append(out, map[string]any{"url": server}) + } + spec["servers"] = out + } + + if sb.ExternalDocsURL != "" { + externalDocs := map[string]any{ + "url": sb.ExternalDocsURL, + } + if sb.ExternalDocsDescription != "" { + externalDocs["description"] = sb.ExternalDocsDescription + } + spec["externalDocs"] = externalDocs } // Add component schemas for the response envelope. spec["components"] = map[string]any{ "schemas": map[string]any{ + "Response": map[string]any{ + "type": "object", + "properties": map[string]any{ + "success": map[string]any{"type": "boolean"}, + "data": map[string]any{}, + "error": map[string]any{ + "$ref": "#/components/schemas/Error", + }, + "meta": map[string]any{ + "$ref": "#/components/schemas/Meta", + }, + }, + "required": []string{"success"}, + }, "Error": map[string]any{ "type": "object", "properties": map[string]any{ @@ -52,13 +254,33 @@ func (sb *SpecBuilder) Build(groups []RouteGroup) ([]byte, error) { }, }, }, + "securitySchemes": securitySchemeComponents(sb.SecuritySchemes), + "headers": deprecationHeaderComponents(), + "responses": responseComponents(), } return json.MarshalIndent(spec, "", " ") } +// BuildIter generates the complete OpenAPI 3.1 JSON spec from a route-group +// iterator. The iterator is snapshotted before building so the result stays +// stable even if the source changes during rendering. +// +// Example: +// +// data, err := (&api.SpecBuilder{Title: "Service"}).BuildIter(api.RegisteredSpecGroupsIter()) +func (sb *SpecBuilder) BuildIter(groups iter.Seq[RouteGroup]) ([]byte, error) { + if sb == nil { + sb = &SpecBuilder{} + } + + return sb.Build(collectRouteGroups(groups)) +} + // buildPaths generates the paths object from all DescribableGroups. -func (sb *SpecBuilder) buildPaths(groups []RouteGroup) map[string]any { +func (sb *SpecBuilder) buildPaths(groups []preparedRouteGroup) map[string]any { + operationIDs := map[string]int{} + publicPaths := sb.effectiveAuthentikPublicPaths() paths := map[string]any{ // Built-in health endpoint. "/health": map[string]any{ @@ -66,61 +288,131 @@ func (sb *SpecBuilder) buildPaths(groups []RouteGroup) map[string]any { "summary": "Health check", "description": "Returns server health status", "tags": []string{"system"}, - "responses": map[string]any{ - "200": map[string]any{ - "description": "Server is healthy", - "content": map[string]any{ - "application/json": map[string]any{ - "schema": envelopeSchema(map[string]any{"type": "string"}), - }, - }, - }, - }, + "operationId": operationID("get", "/health", operationIDs), + "responses": healthResponses(), }, }, } - for _, g := range groups { - dg, ok := g.(DescribableGroup) - if !ok { - continue + graphqlPath := sb.effectiveGraphQLPath() + if graphqlPath != "" { + graphqlPath = normaliseOpenAPIPath(graphqlPath) + item := graphqlPathItem(graphqlPath, operationIDs) + if isPublicPathForList(graphqlPath, publicPaths) { + makePathItemPublic(item) + } + paths[graphqlPath] = item + if sb.GraphQLPlayground { + playgroundPath := sb.effectiveGraphQLPlaygroundPath() + if playgroundPath == "" { + playgroundPath = graphqlPath + "/playground" + } + playgroundPath = normaliseOpenAPIPath(playgroundPath) + item := graphqlPlaygroundPathItem(playgroundPath, operationIDs) + if isPublicPathForList(playgroundPath, publicPaths) { + makePathItemPublic(item) + } + paths[playgroundPath] = item + } + } + + if wsPath := sb.effectiveWSPath(); wsPath != "" { + wsPath = normaliseOpenAPIPath(wsPath) + item := wsPathItem(wsPath, operationIDs) + if isPublicPathForList(wsPath, publicPaths) { + makePathItemPublic(item) + } + paths[wsPath] = item + } + + if ssePath := sb.effectiveSSEPath(); ssePath != "" { + ssePath = normaliseOpenAPIPath(ssePath) + item := ssePathItem(ssePath, operationIDs) + if isPublicPathForList(ssePath, publicPaths) { + makePathItemPublic(item) + } + paths[ssePath] = item + } + + if sb.PprofEnabled { + item := pprofPathItem(operationIDs) + if isPublicPathForList("/debug/pprof", publicPaths) { + makePathItemPublic(item) + } + paths["/debug/pprof"] = item + } + + if sb.ExpvarEnabled { + item := expvarPathItem(operationIDs) + if isPublicPathForList("/debug/vars", publicPaths) { + makePathItemPublic(item) } - for _, rd := range dg.Describe() { - fullPath := g.BasePath() + rd.Path + paths["/debug/vars"] = item + } + + for _, g := range groups { + for _, rd := range g.descs { + fullPath := joinOpenAPIPath(g.basePath, rd.Path) method := strings.ToLower(rd.Method) + deprecated := rd.Deprecated || strings.TrimSpace(rd.SunsetDate) != "" || strings.TrimSpace(rd.Replacement) != "" + deprecationHeaders := deprecationResponseHeaders(deprecated, rd.SunsetDate, rd.Replacement) + isPublic := isPublicPathForList(fullPath, publicPaths) + security := rd.Security + if isPublic { + security = []map[string][]string{} + } operation := map[string]any{ "summary": rd.Summary, "description": rd.Description, - "tags": rd.Tags, - "responses": map[string]any{ - "200": map[string]any{ - "description": "Successful response", - "content": map[string]any{ - "application/json": map[string]any{ - "schema": envelopeSchema(rd.Response), - }, - }, - }, - "400": map[string]any{ - "description": "Bad request", - "content": map[string]any{ - "application/json": map[string]any{ - "schema": envelopeSchema(nil), - }, - }, + "operationId": operationID(method, fullPath, operationIDs), + "responses": operationResponses(method, rd.StatusCode, rd.Response, rd.ResponseExample, rd.ResponseHeaders, security, deprecated, rd.SunsetDate, rd.Replacement, deprecationHeaders, sb.CacheEnabled), + } + if deprecated { + operation["deprecated"] = true + } + if isPublic { + operation["security"] = []any{} + } else if security != nil { + operation["security"] = security + } else { + operation["security"] = []any{ + map[string]any{ + "bearerAuth": []any{}, }, - }, + } + } + if tags := resolvedOperationTags(g.name, rd); len(tags) > 0 { + operation["tags"] = tags + } + + if params := pathParameters(fullPath); len(params) > 0 { + operation["parameters"] = params + } + if explicit := operationParameters(rd.Parameters); len(explicit) > 0 { + operation["parameters"] = mergeOperationParameters(operation["parameters"], explicit) } // Add request body for methods that accept one. - if rd.RequestBody != nil && (method == "post" || method == "put" || method == "patch") { + // The contract only excludes GET; other verbs may legitimately carry bodies. + // An example-only request body still produces a documented payload so + // callers can see the expected shape even when a schema is omitted. + if method != "get" && (rd.RequestBody != nil || rd.RequestExample != nil) { + requestSchema := rd.RequestBody + if requestSchema == nil { + requestSchema = map[string]any{} + } + requestMediaType := map[string]any{ + "schema": requestSchema, + } + if rd.RequestExample != nil { + requestMediaType["example"] = rd.RequestExample + } + operation["requestBody"] = map[string]any{ "required": true, "content": map[string]any{ - "application/json": map[string]any{ - "schema": rd.RequestBody, - }, + "application/json": requestMediaType, }, } } @@ -136,49 +428,1786 @@ func (sb *SpecBuilder) buildPaths(groups []RouteGroup) map[string]any { } } + // The built-in health check remains public, so override the inherited + // default security requirement with an explicit empty array. + if health, ok := paths["/health"].(map[string]any); ok { + if op, ok := health["get"].(map[string]any); ok { + op["security"] = []any{} + } + } + return paths } -// buildTags generates the tags array from all RouteGroups. -func (sb *SpecBuilder) buildTags(groups []RouteGroup) []map[string]any { - tags := []map[string]any{ - {"name": "system", "description": "System endpoints"}, +// joinOpenAPIPath normalises a base path and relative route path into a single +// OpenAPI path without duplicate or missing separators. Gin-style parameters +// such as :id and *path are converted to OpenAPI template parameters. +func joinOpenAPIPath(basePath, routePath string) string { + basePath = strings.TrimSpace(basePath) + routePath = strings.TrimSpace(routePath) + + if basePath == "" { + basePath = "/" + } + if routePath == "" || routePath == "/" { + return normaliseOpenAPIPath(basePath) } - seen := map[string]bool{"system": true} - for _, g := range groups { - name := g.Name() - if !seen[name] { - tags = append(tags, map[string]any{ - "name": name, - "description": name + " endpoints", - }) - seen[name] = true + basePath = normaliseOpenAPIPath(basePath) + routePath = normaliseOpenAPIPath(routePath) + + if basePath == "/" { + return routePath + } + + return strings.TrimRight(basePath, "/") + "/" + strings.TrimPrefix(routePath, "/") +} + +// normaliseOpenAPIPath trims whitespace and collapses trailing separators +// while preserving the root path and converting Gin-style path parameters. +func normaliseOpenAPIPath(path string) string { + path = strings.TrimSpace(path) + if path == "" { + return "/" + } + + segments := strings.Split(path, "/") + cleaned := make([]string, 0, len(segments)) + for _, segment := range segments { + segment = strings.TrimSpace(segment) + if segment == "" { + continue + } + switch { + case strings.HasPrefix(segment, ":") && len(segment) > 1: + segment = "{" + segment[1:] + "}" + case strings.HasPrefix(segment, "*") && len(segment) > 1: + segment = "{" + segment[1:] + "}" } + cleaned = append(cleaned, segment) } - return tags + if len(cleaned) == 0 { + return "/" + } + + return "/" + strings.Join(cleaned, "/") } -// envelopeSchema wraps a data schema in the standard Response[T] envelope. -func envelopeSchema(dataSchema map[string]any) map[string]any { - properties := map[string]any{ - "success": map[string]any{"type": "boolean"}, - "error": map[string]any{ - "$ref": "#/components/schemas/Error", +// operationResponses builds the standard response set for a documented API +// operation. The framework always exposes the common envelope responses, plus +// middleware-driven 429 and 504 errors. +func operationResponses(method string, statusCode int, dataSchema map[string]any, example any, responseHeaders map[string]string, security []map[string][]string, deprecated bool, sunsetDate, replacement string, deprecationHeaders map[string]any, cacheEnabled bool) map[string]any { + documentedHeaders := documentedResponseHeaders(responseHeaders) + successHeaders := mergeHeaders(standardResponseHeaders(), rateLimitSuccessHeaders(), deprecationHeaders, documentedHeaders) + if method == "get" && cacheEnabled { + successHeaders = mergeHeaders(successHeaders, cacheSuccessHeaders()) + } + + isPublic := security != nil && len(security) == 0 + errorHeaders := mergeHeaders(standardResponseHeaders(), rateLimitSuccessHeaders(), deprecationHeaders, documentedHeaders) + + code := successStatusCode(statusCode) + if dataSchema == nil && example != nil { + dataSchema = map[string]any{} + } + successResponse := map[string]any{ + "description": successResponseDescription(code), + "headers": successHeaders, + } + if !isNoContentStatus(code) { + content := map[string]any{ + "schema": envelopeSchema(dataSchema), + } + if example != nil { + // Example payloads are optional, but when a route provides one we + // expose it alongside the schema so generated docs stay useful. + content["example"] = example + } + + successResponse["content"] = map[string]any{ + "application/json": content, + } + } + + responses := map[string]any{ + strconv.Itoa(code): successResponse, + "400": map[string]any{ + "description": "Bad request", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": envelopeSchema(nil), + }, + }, + "headers": errorHeaders, }, - "meta": map[string]any{ - "$ref": "#/components/schemas/Meta", + "429": map[string]any{ + "description": "Too many requests", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": envelopeSchema(nil), + }, + }, + "headers": mergeHeaders(standardResponseHeaders(), rateLimitHeaders(), deprecationHeaders, documentedHeaders), + }, + "504": map[string]any{ + "description": "Gateway timeout", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": envelopeSchema(nil), + }, + }, + "headers": errorHeaders, + }, + "500": map[string]any{ + "description": "Internal server error", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": envelopeSchema(nil), + }, + }, + "headers": errorHeaders, }, } - if dataSchema != nil { - properties["data"] = dataSchema + if deprecated && (strings.TrimSpace(sunsetDate) != "" || strings.TrimSpace(replacement) != "") { + responses["410"] = map[string]any{ + "description": "Gone", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": envelopeSchema(nil), + }, + }, + "headers": errorHeaders, + } + } + + if !isPublic { + responses["401"] = map[string]any{ + "description": "Unauthorised", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": envelopeSchema(nil), + }, + }, + "headers": errorHeaders, + } + responses["403"] = map[string]any{ + "description": "Forbidden", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": envelopeSchema(nil), + }, + }, + "headers": errorHeaders, + } + } + + return responses +} + +func successStatusCode(statusCode int) int { + if statusCode < 200 || statusCode > 299 { + return http.StatusOK + } + if statusCode == 0 { + return http.StatusOK + } + return statusCode +} + +func isNoContentStatus(statusCode int) bool { + switch statusCode { + case http.StatusNoContent, http.StatusResetContent: + return true + default: + return false + } +} + +func successResponseDescription(statusCode int) string { + switch statusCode { + case http.StatusCreated: + return "Created" + case http.StatusAccepted: + return "Accepted" + case http.StatusNoContent: + return "No content" + case http.StatusResetContent: + return "Reset content" + default: + return "Successful response" } +} +// healthResponses builds the response set for the built-in health endpoint. +// It stays public, but rate limiting and timeouts can still apply. +func healthResponses() map[string]any { return map[string]any{ - "type": "object", - "properties": properties, - "required": []string{"success"}, + "200": map[string]any{ + "description": "Server is healthy", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": envelopeSchema(map[string]any{"type": "string"}), + }, + }, + "headers": mergeHeaders(standardResponseHeaders(), rateLimitSuccessHeaders(), cacheSuccessHeaders()), + }, + "429": map[string]any{ + "description": "Too many requests", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": envelopeSchema(nil), + }, + }, + "headers": mergeHeaders(standardResponseHeaders(), rateLimitHeaders()), + }, + "504": map[string]any{ + "description": "Gateway timeout", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": envelopeSchema(nil), + }, + }, + "headers": mergeHeaders(standardResponseHeaders(), rateLimitSuccessHeaders()), + }, + "500": map[string]any{ + "description": "Internal server error", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": envelopeSchema(nil), + }, + }, + "headers": mergeHeaders(standardResponseHeaders(), rateLimitSuccessHeaders()), + }, + } +} + +// deprecationResponseHeaders documents the standard deprecation headers for +// deprecated or sunsetted operations. +func deprecationResponseHeaders(deprecated bool, sunsetDate, replacement string) map[string]any { + sunsetDate = strings.TrimSpace(sunsetDate) + replacement = strings.TrimSpace(replacement) + + if !deprecated && sunsetDate == "" && replacement == "" { + return nil + } + + headers := map[string]any{ + "Deprecation": map[string]any{ + "$ref": "#/components/headers/deprecation", + }, + "X-API-Warn": map[string]any{ + "$ref": "#/components/headers/xapiwarn", + }, + } + + if sunsetDate != "" { + headers["Sunset"] = map[string]any{ + "$ref": "#/components/headers/sunset", + } + } + + if replacement != "" { + headers["Link"] = map[string]any{ + "$ref": "#/components/headers/link", + } + } + + return headers +} + +// deprecationHeaderComponents returns reusable OpenAPI header components for +// the standard deprecation and sunset middleware headers. +func deprecationHeaderComponents() map[string]any { + return map[string]any{ + "deprecation": map[string]any{ + "description": "Indicates that the endpoint is deprecated.", + "schema": map[string]any{ + "type": "string", + "enum": []string{"true"}, + }, + }, + "sunset": map[string]any{ + "description": "The date and time after which the endpoint will no longer be supported.", + "schema": map[string]any{ + "type": "string", + "format": "date-time", + }, + }, + "link": map[string]any{ + "description": "Reference to the successor endpoint, when one is provided.", + "schema": map[string]any{ + "type": "string", + }, + }, + "xapiwarn": map[string]any{ + "description": "Human-readable deprecation warning for clients.", + "schema": map[string]any{ + "type": "string", + }, + }, + } +} + +// responseComponents returns reusable OpenAPI response objects for the +// common error cases exposed by the framework. The path operations still +// inline their concrete headers so existing callers keep the same output, +// but these components make the response catalogue available for reuse. +func responseComponents() map[string]any { + return map[string]any{ + "BadRequest": map[string]any{ + "description": "Bad request", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": envelopeSchema(nil), + }, + }, + "headers": standardResponseHeaders(), + }, + "Unauthorized": map[string]any{ + "description": "Unauthorised", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": envelopeSchema(nil), + }, + }, + "headers": standardResponseHeaders(), + }, + "Forbidden": map[string]any{ + "description": "Forbidden", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": envelopeSchema(nil), + }, + }, + "headers": standardResponseHeaders(), + }, + "RateLimitExceeded": map[string]any{ + "description": "Too many requests", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": envelopeSchema(nil), + }, + }, + "headers": mergeHeaders(standardResponseHeaders(), rateLimitHeaders()), + }, + "GatewayTimeout": map[string]any{ + "description": "Gateway timeout", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": envelopeSchema(nil), + }, + }, + "headers": standardResponseHeaders(), + }, + "InternalServerError": map[string]any{ + "description": "Internal server error", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": envelopeSchema(nil), + }, + }, + "headers": standardResponseHeaders(), + }, + "Gone": map[string]any{ + "description": "Gone", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": envelopeSchema(nil), + }, + }, + "headers": mergeHeaders(standardResponseHeaders(), deprecationResponseHeaders(true, "", "")), + }, + } +} + +// securitySchemeComponents builds the OpenAPI security scheme registry. +// bearerAuth stays available by default, while callers can add or override +// additional scheme definitions for custom security requirements. +func securitySchemeComponents(overrides map[string]any) map[string]any { + schemes := map[string]any{ + "bearerAuth": map[string]any{ + "type": "http", + "scheme": "bearer", + "bearerFormat": "JWT", + }, + } + + for name, scheme := range overrides { + name = strings.TrimSpace(name) + if name == "" || scheme == nil { + continue + } + schemes[name] = cloneOpenAPIValue(scheme) + } + + return schemes +} + +// buildTags generates the tags array from all RouteGroups. +func (sb *SpecBuilder) buildTags(groups []preparedRouteGroup) []map[string]any { + tags := []map[string]any{ + {"name": "system", "description": "System endpoints"}, + } + seen := map[string]bool{"system": true} + + if graphqlPath := sb.effectiveGraphQLPath(); graphqlPath != "" && !seen["graphql"] { + tags = append(tags, map[string]any{ + "name": "graphql", + "description": "GraphQL endpoints", + }) + seen["graphql"] = true + } + + if ssePath := sb.effectiveSSEPath(); ssePath != "" && !seen["events"] { + tags = append(tags, map[string]any{ + "name": "events", + "description": "Server-Sent Events endpoints", + }) + seen["events"] = true + } + + if (sb.PprofEnabled || sb.ExpvarEnabled) && !seen["debug"] { + tags = append(tags, map[string]any{ + "name": "debug", + "description": "Runtime debug endpoints", + }) + seen["debug"] = true + } + + for _, g := range groups { + name := strings.TrimSpace(g.name) + if name != "" && !seen[name] { + tags = append(tags, map[string]any{ + "name": name, + "description": name + " endpoints", + }) + seen[name] = true + } + + for _, rd := range g.descs { + for _, tag := range rd.Tags { + tag = strings.TrimSpace(tag) + if tag == "" || seen[tag] { + continue + } + tags = append(tags, map[string]any{ + "name": tag, + "description": tag + " endpoints", + }) + seen[tag] = true + } + } + } + + sortTags(tags) + + return tags +} + +// sortTags keeps system first and orders the remaining tags alphabetically so +// generated specs stay stable across registration order changes. +func sortTags(tags []map[string]any) { + if len(tags) < 2 { + return + } + + sort.SliceStable(tags, func(i, j int) bool { + left, _ := tags[i]["name"].(string) + right, _ := tags[j]["name"].(string) + + switch { + case left == "system": + return true + case right == "system": + return false + default: + return left < right + } + }) +} + +func graphqlPathItem(path string, operationIDs map[string]int) map[string]any { + return map[string]any{ + "get": map[string]any{ + "summary": "GraphQL query", + "description": "Executes GraphQL queries over GET using query parameters", + "tags": []string{"graphql"}, + "operationId": operationID("get", path, operationIDs), + "security": []any{ + map[string]any{ + "bearerAuth": []any{}, + }, + }, + "parameters": graphqlQueryParameters(), + "responses": graphqlResponses(), + }, + "post": map[string]any{ + "summary": "GraphQL query", + "description": "Executes GraphQL queries and mutations", + "tags": []string{"graphql"}, + "operationId": operationID("post", path, operationIDs), + "security": []any{ + map[string]any{ + "bearerAuth": []any{}, + }, + }, + "requestBody": map[string]any{ + "required": true, + "content": map[string]any{ + "application/json": map[string]any{ + "schema": graphqlRequestSchema(), + }, + }, + }, + "responses": graphqlResponses(), + }, + } +} + +func graphqlPlaygroundPathItem(path string, operationIDs map[string]int) map[string]any { + return map[string]any{ + "get": map[string]any{ + "summary": "GraphQL playground", + "description": "Interactive GraphQL IDE for the configured schema", + "tags": []string{"graphql"}, + "operationId": operationID("get", path, operationIDs), + "security": []any{ + map[string]any{ + "bearerAuth": []any{}, + }, + }, + "responses": graphqlPlaygroundResponses(), + }, + } +} + +func wsPathItem(path string, operationIDs map[string]int) map[string]any { + return map[string]any{ + "get": map[string]any{ + "summary": "WebSocket connection", + "description": "Upgrades the connection to a WebSocket stream", + "tags": []string{"system"}, + "operationId": operationID("get", path, operationIDs), + "security": []any{ + map[string]any{ + "bearerAuth": []any{}, + }, + }, + "responses": wsResponses(), + }, + } +} + +func ssePathItem(path string, operationIDs map[string]int) map[string]any { + return map[string]any{ + "get": map[string]any{ + "summary": "Server-Sent Events stream", + "description": "Streams published events as text/event-stream", + "tags": []string{"events"}, + "operationId": operationID("get", path, operationIDs), + "security": []any{ + map[string]any{ + "bearerAuth": []any{}, + }, + }, + "parameters": []map[string]any{ + { + "name": "channel", + "in": "query", + "required": false, + "description": "Restrict the stream to a specific channel", + "schema": map[string]any{ + "type": "string", + }, + }, + }, + "responses": sseResponses(), + }, + } +} + +func wsResponses() map[string]any { + successHeaders := mergeHeaders( + standardResponseHeaders(), + rateLimitSuccessHeaders(), + wsUpgradeHeaders(), + ) + errorHeaders := mergeHeaders(standardResponseHeaders(), rateLimitSuccessHeaders()) + + return map[string]any{ + "101": map[string]any{ + "description": "Switching protocols", + "headers": successHeaders, + }, + "401": map[string]any{ + "description": "Unauthorised", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, + }, + "headers": errorHeaders, + }, + "403": map[string]any{ + "description": "Forbidden", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, + }, + "headers": errorHeaders, + }, + "429": map[string]any{ + "description": "Too many requests", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, + }, + "headers": mergeHeaders(standardResponseHeaders(), rateLimitHeaders()), + }, + "500": map[string]any{ + "description": "Internal server error", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, + }, + "headers": errorHeaders, + }, + "504": map[string]any{ + "description": "Gateway timeout", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, + }, + "headers": errorHeaders, + }, + } +} + +func wsUpgradeHeaders() map[string]any { + return map[string]any{ + "Upgrade": map[string]any{ + "description": "Indicates that the connection has switched to WebSocket", + "schema": map[string]any{ + "type": "string", + }, + }, + "Connection": map[string]any{ + "description": "Keeps the upgraded connection open", + "schema": map[string]any{ + "type": "string", + }, + }, + "Sec-WebSocket-Accept": map[string]any{ + "description": "Validates the WebSocket handshake", + "schema": map[string]any{ + "type": "string", + }, + }, + } +} + +func pprofPathItem(operationIDs map[string]int) map[string]any { + successHeaders := mergeHeaders(standardResponseHeaders(), rateLimitSuccessHeaders()) + errorHeaders := mergeHeaders(standardResponseHeaders(), rateLimitSuccessHeaders()) + + return map[string]any{ + "get": map[string]any{ + "summary": "pprof index", + "description": "Lists the available Go runtime profiles", + "tags": []string{"debug"}, + "operationId": operationID("get", "/debug/pprof", operationIDs), + "security": []any{ + map[string]any{ + "bearerAuth": []any{}, + }, + }, + "responses": map[string]any{ + "200": map[string]any{ + "description": "pprof index", + "content": map[string]any{ + "text/html": map[string]any{ + "schema": map[string]any{ + "type": "string", + }, + }, + }, + "headers": successHeaders, + }, + "401": map[string]any{ + "description": "Unauthorised", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, + }, + "headers": errorHeaders, + }, + "403": map[string]any{ + "description": "Forbidden", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, + }, + "headers": errorHeaders, + }, + }, + }, + } +} + +func expvarPathItem(operationIDs map[string]int) map[string]any { + successHeaders := mergeHeaders(standardResponseHeaders(), rateLimitSuccessHeaders()) + errorHeaders := mergeHeaders(standardResponseHeaders(), rateLimitSuccessHeaders()) + + return map[string]any{ + "get": map[string]any{ + "summary": "Runtime metrics", + "description": "Returns expvar metrics as JSON", + "tags": []string{"debug"}, + "operationId": operationID("get", "/debug/vars", operationIDs), + "security": []any{ + map[string]any{ + "bearerAuth": []any{}, + }, + }, + "responses": map[string]any{ + "200": map[string]any{ + "description": "Runtime metrics", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, + }, + "headers": successHeaders, + }, + "401": map[string]any{ + "description": "Unauthorised", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, + }, + "headers": errorHeaders, + }, + "403": map[string]any{ + "description": "Forbidden", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, + }, + "headers": errorHeaders, + }, + }, + }, + } +} + +func graphqlRequestSchema() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "query": map[string]any{ + "type": "string", + }, + "variables": map[string]any{ + "type": "object", + "additionalProperties": true, + }, + "operationName": map[string]any{ + "type": "string", + }, + }, + "required": []string{"query"}, + } +} + +func graphqlQueryParameters() []map[string]any { + return []map[string]any{ + { + "name": "query", + "in": "query", + "required": true, + "description": "GraphQL query or mutation document", + "schema": map[string]any{ + "type": "string", + }, + }, + { + "name": "variables", + "in": "query", + "required": false, + "description": "JSON-encoded GraphQL variables", + "schema": map[string]any{ + "type": "string", + }, + }, + { + "name": "operationName", + "in": "query", + "required": false, + "description": "Operation name to execute", + "schema": map[string]any{ + "type": "string", + }, + }, + } +} + +func graphqlResponses() map[string]any { + successHeaders := mergeHeaders(standardResponseHeaders(), rateLimitSuccessHeaders(), cacheSuccessHeaders()) + errorHeaders := mergeHeaders(standardResponseHeaders(), rateLimitSuccessHeaders()) + + return map[string]any{ + "200": map[string]any{ + "description": "GraphQL response", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, + }, + "headers": successHeaders, + }, + "400": map[string]any{ + "description": "Bad request", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, + }, + "headers": errorHeaders, + }, + "401": map[string]any{ + "description": "Unauthorised", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, + }, + "headers": errorHeaders, + }, + "403": map[string]any{ + "description": "Forbidden", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, + }, + "headers": errorHeaders, + }, + "429": map[string]any{ + "description": "Too many requests", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, + }, + "headers": mergeHeaders(standardResponseHeaders(), rateLimitHeaders()), + }, + "500": map[string]any{ + "description": "Internal server error", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, + }, + "headers": errorHeaders, + }, + "504": map[string]any{ + "description": "Gateway timeout", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, + }, + "headers": errorHeaders, + }, + } +} + +func graphqlPlaygroundResponses() map[string]any { + successHeaders := mergeHeaders(standardResponseHeaders(), rateLimitSuccessHeaders()) + errorHeaders := mergeHeaders(standardResponseHeaders(), rateLimitSuccessHeaders()) + + return map[string]any{ + "200": map[string]any{ + "description": "GraphQL playground", + "content": map[string]any{ + "text/html": map[string]any{ + "schema": map[string]any{ + "type": "string", + }, + }, + }, + "headers": successHeaders, + }, + "401": map[string]any{ + "description": "Unauthorised", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, + }, + "headers": errorHeaders, + }, + "403": map[string]any{ + "description": "Forbidden", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, + }, + "headers": errorHeaders, + }, + "429": map[string]any{ + "description": "Too many requests", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, + }, + "headers": mergeHeaders(standardResponseHeaders(), rateLimitHeaders()), + }, + "500": map[string]any{ + "description": "Internal server error", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, + }, + "headers": errorHeaders, + }, + "504": map[string]any{ + "description": "Gateway timeout", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, + }, + "headers": errorHeaders, + }, + } +} + +func sseResponses() map[string]any { + successHeaders := mergeHeaders( + standardResponseHeaders(), + rateLimitSuccessHeaders(), + sseResponseHeaders(), + ) + errorHeaders := mergeHeaders(standardResponseHeaders(), rateLimitSuccessHeaders()) + + return map[string]any{ + "200": map[string]any{ + "description": "Event stream", + "content": map[string]any{ + "text/event-stream": map[string]any{ + "schema": map[string]any{ + "type": "string", + }, + }, + }, + "headers": successHeaders, + }, + "401": map[string]any{ + "description": "Unauthorised", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, + }, + "headers": errorHeaders, + }, + "403": map[string]any{ + "description": "Forbidden", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, + }, + "headers": errorHeaders, + }, + "429": map[string]any{ + "description": "Too many requests", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, + }, + "headers": mergeHeaders(standardResponseHeaders(), rateLimitHeaders()), + }, + "500": map[string]any{ + "description": "Internal server error", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, + }, + "headers": errorHeaders, + }, + "504": map[string]any{ + "description": "Gateway timeout", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, + }, + "headers": errorHeaders, + }, + } +} + +// prepareRouteGroups snapshots route descriptions once per group so iterator- +// backed implementations can be consumed safely by both tag and path builders. +func prepareRouteGroups(groups []RouteGroup) []preparedRouteGroup { + if len(groups) == 0 { + return nil + } + + out := make([]preparedRouteGroup, 0, len(groups)) + for _, g := range groups { + if g == nil || isNilRouteGroup(g) { + continue + } + if isHiddenRouteGroup(g) { + continue + } + out = append(out, preparedRouteGroup{ + name: g.Name(), + basePath: g.BasePath(), + descs: collectRouteDescriptions(g), + }) + } + + return out +} + +func collectRouteGroups(groups iter.Seq[RouteGroup]) []RouteGroup { + if groups == nil { + return nil + } + + out := make([]RouteGroup, 0) + for group := range groups { + out = append(out, group) + } + + return out +} + +func collectRouteDescriptions(g RouteGroup) []RouteDescription { + descIter := routeDescriptions(g) + if descIter == nil { + return nil + } + + descs := make([]RouteDescription, 0) + for rd := range descIter { + if rd.Hidden { + continue + } + descs = append(descs, cloneRouteDescription(rd)) + } + + return descs +} + +func isHiddenRouteGroup(g RouteGroup) bool { + type hiddenRouteGroup interface { + Hidden() bool + } + + hg, ok := g.(hiddenRouteGroup) + return ok && hg.Hidden() +} + +// routeDescriptions returns OpenAPI route descriptions for a group. +// Iterator-backed implementations are preferred when available so builders +// can avoid slice allocation. +func routeDescriptions(g RouteGroup) iter.Seq[RouteDescription] { + if dg, ok := g.(DescribableGroupIter); ok { + if descIter := dg.DescribeIter(); descIter != nil { + return descIter + } + } + if dg, ok := g.(DescribableGroup); ok { + descs := dg.Describe() + return func(yield func(RouteDescription) bool) { + for _, rd := range descs { + if !yield(rd) { + return + } + } + } + } + return nil +} + +// pathParameters extracts unique OpenAPI path parameters from a path template. +// Parameters are returned in the order they appear in the path. +func pathParameters(path string) []map[string]any { + const ( + open = '{' + close = '}' + ) + + seen := map[string]bool{} + params := make([]map[string]any, 0) + + for i := 0; i < len(path); i++ { + if path[i] != open { + continue + } + end := strings.IndexByte(path[i+1:], close) + if end < 0 { + continue + } + name := path[i+1 : i+1+end] + if name == "" || strings.ContainsAny(name, "/{}") || seen[name] { + continue + } + seen[name] = true + params = append(params, map[string]any{ + "name": name, + "in": "path", + "required": true, + "schema": map[string]any{ + "type": "string", + }, + }) + i += end + 1 + } + + return params +} + +// operationParameters converts explicit route parameter descriptions into +// OpenAPI parameter objects. +func operationParameters(params []ParameterDescription) []map[string]any { + if len(params) == 0 { + return nil + } + + out := make([]map[string]any, 0, len(params)) + for _, param := range params { + if param.Name == "" || param.In == "" { + continue + } + + entry := map[string]any{ + "name": param.Name, + "in": param.In, + "required": param.Required || param.In == "path", + } + if param.Description != "" { + entry["description"] = param.Description + } + if param.Deprecated { + entry["deprecated"] = true + } + if len(param.Schema) > 0 { + entry["schema"] = param.Schema + } else if param.In == "path" || param.In == "query" || param.In == "header" || param.In == "cookie" { + entry["schema"] = map[string]any{"type": "string"} + } + if param.Example != nil { + entry["example"] = param.Example + } + + out = append(out, entry) + } + + return out +} + +// mergeOperationParameters combines generated and explicit parameter +// definitions, letting explicit entries override auto-generated path params. +func mergeOperationParameters(existing any, explicit []map[string]any) []map[string]any { + merged := make([]map[string]any, 0, len(explicit)) + index := map[string]int{} + + add := func(param map[string]any) { + name, _ := param["name"].(string) + in, _ := param["in"].(string) + if name == "" || in == "" { + return + } + key := in + ":" + name + if pos, ok := index[key]; ok { + merged[pos] = param + return + } + index[key] = len(merged) + merged = append(merged, param) + } + + if params, ok := existing.([]map[string]any); ok { + for _, param := range params { + add(param) + } + } + + for _, param := range explicit { + add(param) + } + + if len(merged) == 0 { + return nil + } + + return merged +} + +// resolvedOperationTags returns the explicit route tags when provided, or a +// stable fallback derived from the group's name when the route omits tags. +func resolvedOperationTags(groupName string, rd RouteDescription) []string { + if tags := cleanTags(rd.Tags); len(tags) > 0 { + return tags + } + + if name := strings.TrimSpace(groupName); name != "" { + return []string{name} + } + + return nil +} + +// cleanTags trims whitespace and removes empty or duplicate tags while +// preserving the first occurrence of each name. +func cleanTags(tags []string) []string { + if len(tags) == 0 { + return nil + } + + cleaned := make([]string, 0, len(tags)) + seen := make(map[string]struct{}, len(tags)) + for _, tag := range tags { + tag = strings.TrimSpace(tag) + if tag == "" { + continue + } + if _, ok := seen[tag]; ok { + continue + } + seen[tag] = struct{}{} + cleaned = append(cleaned, tag) + } + if len(cleaned) == 0 { + return nil + } + return cleaned +} + +// envelopeSchema wraps a data schema in the standard Response[T] envelope. +func envelopeSchema(dataSchema map[string]any) map[string]any { + properties := map[string]any{ + "success": map[string]any{"type": "boolean"}, + "error": map[string]any{ + "$ref": "#/components/schemas/Error", + }, + "meta": map[string]any{ + "$ref": "#/components/schemas/Meta", + }, + } + + if dataSchema != nil { + properties["data"] = dataSchema + } + + return map[string]any{ + "type": "object", + "properties": properties, + "required": []string{"success"}, + } +} + +// rateLimitHeaders documents the response headers emitted when rate limiting +// rejects a request. +func rateLimitHeaders() map[string]any { + return map[string]any{ + "X-RateLimit-Limit": map[string]any{ + "description": "Maximum number of requests allowed in the current window", + "schema": map[string]any{ + "type": "integer", + "minimum": 1, + }, + }, + "X-RateLimit-Remaining": map[string]any{ + "description": "Number of requests remaining in the current window", + "schema": map[string]any{ + "type": "integer", + "minimum": 0, + }, + }, + "X-RateLimit-Reset": map[string]any{ + "description": "Unix timestamp when the rate limit window resets", + "schema": map[string]any{ + "type": "integer", + "minimum": 1, + }, + }, + "Retry-After": map[string]any{ + "description": "Seconds until the rate limit resets", + "schema": map[string]any{ + "type": "integer", + "minimum": 1, + }, + }, + } +} + +// rateLimitSuccessHeaders documents the response headers emitted on +// successful requests when rate limiting is enabled. +func rateLimitSuccessHeaders() map[string]any { + return map[string]any{ + "X-RateLimit-Limit": map[string]any{ + "description": "Maximum number of requests allowed in the current window", + "schema": map[string]any{ + "type": "integer", + "minimum": 1, + }, + }, + "X-RateLimit-Remaining": map[string]any{ + "description": "Number of requests remaining in the current window", + "schema": map[string]any{ + "type": "integer", + "minimum": 0, + }, + }, + "X-RateLimit-Reset": map[string]any{ + "description": "Unix timestamp when the rate limit window resets", + "schema": map[string]any{ + "type": "integer", + "minimum": 1, + }, + }, + } +} + +// cacheSuccessHeaders documents the response header emitted on successful +// cache hits. +func cacheSuccessHeaders() map[string]any { + return map[string]any{ + "X-Cache": map[string]any{ + "description": "Indicates the response was served from the in-memory cache", + "schema": map[string]any{ + "type": "string", + }, + }, + } +} + +// sseResponseHeaders documents the response headers emitted by the SSE stream. +func sseResponseHeaders() map[string]any { + return map[string]any{ + "Cache-Control": map[string]any{ + "description": "Prevents intermediaries from caching the event stream", + "schema": map[string]any{ + "type": "string", + }, + }, + "Connection": map[string]any{ + "description": "Keeps the HTTP connection open for streaming", + "schema": map[string]any{ + "type": "string", + }, + }, + "X-Accel-Buffering": map[string]any{ + "description": "Disables buffering in compatible reverse proxies", + "schema": map[string]any{ + "type": "string", + }, + }, + } +} + +// effectiveGraphQLPath returns the configured GraphQL path or the default +// GraphQL path when GraphQL is enabled without an explicit path. Returns an +// empty string when neither GraphQL nor the playground is enabled. +func (sb *SpecBuilder) effectiveGraphQLPath() string { + if !sb.GraphQLEnabled && !sb.GraphQLPlayground { + return "" + } + graphqlPath := strings.TrimSpace(sb.GraphQLPath) + if graphqlPath == "" { + return defaultGraphQLPath + } + return graphqlPath +} + +// effectiveGraphQLPlaygroundPath returns the configured playground path when +// GraphQL playground is enabled. +func (sb *SpecBuilder) effectiveGraphQLPlaygroundPath() string { + if !sb.GraphQLPlayground { + return "" + } + + path := strings.TrimSpace(sb.GraphQLPlaygroundPath) + if path != "" { + return path + } + + base := sb.effectiveGraphQLPath() + if base == "" { + base = defaultGraphQLPath + } + + return base + "/playground" +} + +// effectiveSwaggerPath returns the configured Swagger UI path or the default +// path when Swagger is enabled without an explicit override. Returns an empty +// string when Swagger is disabled. +func (sb *SpecBuilder) effectiveSwaggerPath() string { + if !sb.SwaggerEnabled { + return "" + } + swaggerPath := strings.TrimSpace(sb.SwaggerPath) + if swaggerPath == "" { + return defaultSwaggerPath + } + return swaggerPath +} + +// effectiveWSPath returns the configured WebSocket path or the default path +// when WebSockets are enabled without an explicit override. Returns an empty +// string when WebSockets are disabled. +func (sb *SpecBuilder) effectiveWSPath() string { + if !sb.WSEnabled { + return "" + } + wsPath := strings.TrimSpace(sb.WSPath) + if wsPath == "" { + return defaultWSPath + } + return wsPath +} + +// effectiveSSEPath returns the configured SSE path or the default path when +// SSE is enabled without an explicit override. Returns an empty string when +// SSE is disabled. +func (sb *SpecBuilder) effectiveSSEPath() string { + if !sb.SSEEnabled { + return "" + } + ssePath := strings.TrimSpace(sb.SSEPath) + if ssePath == "" { + return defaultSSEPath + } + return ssePath +} + +// effectiveCacheTTL returns a normalised cache TTL when it parses to a +// positive duration. +func (sb *SpecBuilder) effectiveCacheTTL() string { + ttl := strings.TrimSpace(sb.CacheTTL) + if ttl == "" { + return "" + } + + d, err := time.ParseDuration(ttl) + if err != nil || d <= 0 { + return "" + } + + return ttl +} + +// effectiveAuthentikPublicPaths returns the public paths that Authentik skips +// in practice, including the always-public health and Swagger endpoints. +func (sb *SpecBuilder) effectiveAuthentikPublicPaths() []string { + if !sb.hasAuthentikMetadata() { + return nil + } + + paths := []string{"/health"} + if swaggerPath := sb.effectiveSwaggerPath(); swaggerPath != "" { + paths = append(paths, swaggerPath) + } + paths = append(paths, sb.AuthentikPublicPaths...) + return normalisePublicPaths(paths) +} + +// snapshot returns a trimmed copy of the builder so Build operates on stable +// input even when callers reuse or mutate their original configuration. +func (sb *SpecBuilder) snapshot() *SpecBuilder { + if sb == nil { + return &SpecBuilder{} + } + + out := *sb + out.Title = strings.TrimSpace(out.Title) + out.Summary = strings.TrimSpace(out.Summary) + out.Description = strings.TrimSpace(out.Description) + out.Version = strings.TrimSpace(out.Version) + out.SwaggerPath = strings.TrimSpace(out.SwaggerPath) + out.GraphQLPath = strings.TrimSpace(out.GraphQLPath) + out.GraphQLPlaygroundPath = strings.TrimSpace(out.GraphQLPlaygroundPath) + out.WSPath = strings.TrimSpace(out.WSPath) + out.SSEPath = strings.TrimSpace(out.SSEPath) + out.TermsOfService = strings.TrimSpace(out.TermsOfService) + out.ContactName = strings.TrimSpace(out.ContactName) + out.ContactURL = strings.TrimSpace(out.ContactURL) + out.ContactEmail = strings.TrimSpace(out.ContactEmail) + out.LicenseName = strings.TrimSpace(out.LicenseName) + out.LicenseURL = strings.TrimSpace(out.LicenseURL) + out.ExternalDocsDescription = strings.TrimSpace(out.ExternalDocsDescription) + out.ExternalDocsURL = strings.TrimSpace(out.ExternalDocsURL) + out.CacheTTL = strings.TrimSpace(out.CacheTTL) + out.I18nDefaultLocale = strings.TrimSpace(out.I18nDefaultLocale) + out.Servers = slices.Clone(sb.Servers) + out.I18nSupportedLocales = slices.Clone(sb.I18nSupportedLocales) + out.AuthentikPublicPaths = normalisePublicPaths(sb.AuthentikPublicPaths) + out.SecuritySchemes = cloneSecuritySchemes(sb.SecuritySchemes) + + return &out +} + +// isPublicOperationPath reports whether an OpenAPI path should be documented +// as public because Authentik bypasses it in the running engine. +func (sb *SpecBuilder) isPublicOperationPath(path string) bool { + return isPublicPathForList(path, sb.effectiveAuthentikPublicPaths()) +} + +// hasAuthentikMetadata reports whether the spec carries any Authentik-related +// configuration worth surfacing. +func (sb *SpecBuilder) hasAuthentikMetadata() bool { + if sb == nil { + return false + } + + return strings.TrimSpace(sb.AuthentikIssuer) != "" || + strings.TrimSpace(sb.AuthentikClientID) != "" || + sb.AuthentikTrustedProxy || + len(sb.AuthentikPublicPaths) > 0 +} + +// makePathItemPublic strips auth-specific responses and marks every operation +// within the path item as public. +func makePathItemPublic(pathItem map[string]any) { + for _, rawOperation := range pathItem { + operation, ok := rawOperation.(map[string]any) + if !ok { + continue + } + + operation["security"] = []any{} + responses, ok := operation["responses"].(map[string]any) + if !ok { + continue + } + delete(responses, "401") + delete(responses, "403") + } +} + +// isPublicPathForList reports whether path should be documented as public +// when compared against a precomputed list of public paths. +func isPublicPathForList(path string, publicPaths []string) bool { + for _, publicPath := range publicPaths { + if isPublicPath(path, publicPath) { + return true + } + } + return false +} + +// documentedResponseHeaders converts route-specific response header metadata +// into OpenAPI header objects. +func documentedResponseHeaders(headers map[string]string) map[string]any { + if len(headers) == 0 { + return nil + } + + out := make(map[string]any, len(headers)) + for name, description := range headers { + name = strings.TrimSpace(name) + if name == "" { + continue + } + out[name] = map[string]any{ + "description": description, + "schema": map[string]any{ + "type": "string", + }, + } + } + if len(out) == 0 { + return nil + } + return out +} + +// standardResponseHeaders documents headers emitted by the response envelope +// middleware on all responses when request IDs are enabled. +func standardResponseHeaders() map[string]any { + return map[string]any{ + "X-Request-ID": map[string]any{ + "description": "Request identifier propagated from the client or generated by the server", + "schema": map[string]any{ + "type": "string", + }, + }, + } +} + +// mergeHeaders combines multiple OpenAPI header maps into one. +func mergeHeaders(sets ...map[string]any) map[string]any { + merged := make(map[string]any) + for _, set := range sets { + for name, value := range set { + merged[name] = value + } + } + return merged +} + +// operationID builds a stable OpenAPI operationId from the HTTP method and path. +// The generated identifier is lower snake_case and preserves path parameter names. +func operationID(method, path string, operationIDs map[string]int) string { + var b strings.Builder + b.Grow(len(method) + len(path) + 1) + lastUnderscore := false + + writeUnderscore := func() { + if b.Len() > 0 && !lastUnderscore { + b.WriteByte('_') + lastUnderscore = true + } + } + + appendToken := func(r rune) { + if unicode.IsLetter(r) || unicode.IsDigit(r) { + if unicode.IsUpper(r) { + r = unicode.ToLower(r) + } + b.WriteRune(r) + lastUnderscore = false + return + } + writeUnderscore() + } + + for _, r := range method { + appendToken(r) + } + writeUnderscore() + for _, r := range path { + switch r { + case '/': + writeUnderscore() + case '-': + writeUnderscore() + case '.': + writeUnderscore() + case ' ': + writeUnderscore() + default: + appendToken(r) + } + } + + out := strings.Trim(b.String(), "_") + if out == "" { + return "operation" + } + + if operationIDs == nil { + return out + } + + count := operationIDs[out] + operationIDs[out] = count + 1 + if count == 0 { + return out } + return out + "_" + strconv.Itoa(count+1) } diff --git a/openapi_test.go b/openapi_test.go index ed4a9b6..aa67c86 100644 --- a/openapi_test.go +++ b/openapi_test.go @@ -4,8 +4,10 @@ package api_test import ( "encoding/json" + "iter" "net/http" "testing" + "time" "github.com/gin-gonic/gin" @@ -17,6 +19,7 @@ import ( type specStubGroup struct { name string basePath string + hidden bool descs []api.RouteDescription } @@ -24,6 +27,7 @@ func (s *specStubGroup) Name() string { return s.name } func (s *specStubGroup) BasePath() string { return s.basePath } func (s *specStubGroup) RegisterRoutes(rg *gin.RouterGroup) {} func (s *specStubGroup) Describe() []api.RouteDescription { return s.descs } +func (s *specStubGroup) Hidden() bool { return s.hidden } type plainStubGroup struct{} @@ -31,6 +35,111 @@ func (plainStubGroup) Name() string { return "plain" } func (plainStubGroup) BasePath() string { return "/plain" } func (plainStubGroup) RegisterRoutes(rg *gin.RouterGroup) {} +type iterStubGroup struct { + name string + basePath string + descs []api.RouteDescription +} + +func (s *iterStubGroup) Name() string { return s.name } +func (s *iterStubGroup) BasePath() string { return s.basePath } +func (s *iterStubGroup) RegisterRoutes(rg *gin.RouterGroup) {} +func (s *iterStubGroup) Describe() []api.RouteDescription { return nil } +func (s *iterStubGroup) DescribeIter() iter.Seq[api.RouteDescription] { + return func(yield func(api.RouteDescription) bool) { + for _, rd := range s.descs { + if !yield(rd) { + return + } + } + } +} + +type iterNilFallbackGroup struct { + name string + basePath string + descs []api.RouteDescription +} + +func (s *iterNilFallbackGroup) Name() string { return s.name } +func (s *iterNilFallbackGroup) BasePath() string { return s.basePath } +func (s *iterNilFallbackGroup) RegisterRoutes(rg *gin.RouterGroup) {} +func (s *iterNilFallbackGroup) Describe() []api.RouteDescription { return s.descs } +func (s *iterNilFallbackGroup) DescribeIter() iter.Seq[api.RouteDescription] { + return nil +} + +type countingIterGroup struct { + name string + basePath string + descs []api.RouteDescription + describeCalls int +} + +func (s *countingIterGroup) Name() string { return s.name } +func (s *countingIterGroup) BasePath() string { return s.basePath } +func (s *countingIterGroup) RegisterRoutes(rg *gin.RouterGroup) {} +func (s *countingIterGroup) Describe() []api.RouteDescription { return nil } +func (s *countingIterGroup) DescribeIter() iter.Seq[api.RouteDescription] { + s.describeCalls++ + return func(yield func(api.RouteDescription) bool) { + for _, rd := range s.descs { + if !yield(rd) { + return + } + } + } +} + +type mutatingIterGroup struct { + name string + basePath string + descs []api.RouteDescription +} + +func (s *mutatingIterGroup) Name() string { return s.name } +func (s *mutatingIterGroup) BasePath() string { return s.basePath } +func (s *mutatingIterGroup) RegisterRoutes(rg *gin.RouterGroup) {} +func (s *mutatingIterGroup) Describe() []api.RouteDescription { return nil } +func (s *mutatingIterGroup) DescribeIter() iter.Seq[api.RouteDescription] { + return func(yield func(api.RouteDescription) bool) { + for i, rd := range s.descs { + if !yield(rd) { + return + } + s.descs[i].Response["mutated"] = true + s.descs[i].RequestBody["mutated"] = true + s.descs[i].Parameters[0].Schema["mutated"] = true + s.descs[i].ResponseHeaders["X-Mutated"] = "yes" + } + } +} + +type snapshottingGroup struct { + nameCalls int + basePathCalls int + descs []api.RouteDescription +} + +func (s *snapshottingGroup) Name() string { + s.nameCalls++ + if s.nameCalls == 1 { + return "alpha" + } + return "beta" +} + +func (s *snapshottingGroup) BasePath() string { + s.basePathCalls++ + if s.basePathCalls == 1 { + return "/alpha" + } + return "/beta" +} + +func (s *snapshottingGroup) RegisterRoutes(rg *gin.RouterGroup) {} +func (s *snapshottingGroup) Describe() []api.RouteDescription { return s.descs } + // ── SpecBuilder tests ───────────────────────────────────────────────────── func TestSpecBuilder_Good_EmptyGroups(t *testing.T) { @@ -54,12 +163,62 @@ func TestSpecBuilder_Good_EmptyGroups(t *testing.T) { if spec["openapi"] != "3.1.0" { t.Fatalf("expected openapi=3.1.0, got %v", spec["openapi"]) } + if spec["jsonSchemaDialect"] != "https://spec.openapis.org/oas/3.1/dialect/base" { + t.Fatalf("expected jsonSchemaDialect to use the OpenAPI 3.1 base dialect, got %v", spec["jsonSchemaDialect"]) + } // Verify /health path exists. paths := spec["paths"].(map[string]any) if _, ok := paths["/health"]; !ok { t.Fatal("expected /health path in spec") } + health := paths["/health"].(map[string]any)["get"].(map[string]any) + healthResponses := health["responses"].(map[string]any) + if _, ok := healthResponses["429"]; !ok { + t.Fatal("expected 429 response on /health") + } + if _, ok := healthResponses["504"]; !ok { + t.Fatal("expected 504 response on /health") + } + if _, ok := healthResponses["500"]; !ok { + t.Fatal("expected 500 response on /health") + } + rateLimit429 := healthResponses["429"].(map[string]any) + headers := rateLimit429["headers"].(map[string]any) + if _, ok := headers["Retry-After"]; !ok { + t.Fatal("expected Retry-After header on /health 429 response") + } + if _, ok := headers["X-Request-ID"]; !ok { + t.Fatal("expected X-Request-ID header on /health 429 response") + } + if _, ok := headers["X-RateLimit-Limit"]; !ok { + t.Fatal("expected X-RateLimit-Limit header on /health 429 response") + } + if _, ok := headers["X-RateLimit-Remaining"]; !ok { + t.Fatal("expected X-RateLimit-Remaining header on /health 429 response") + } + if _, ok := headers["X-RateLimit-Reset"]; !ok { + t.Fatal("expected X-RateLimit-Reset header on /health 429 response") + } + health504 := healthResponses["504"].(map[string]any) + health504Headers := health504["headers"].(map[string]any) + if _, ok := health504Headers["X-Request-ID"]; !ok { + t.Fatal("expected X-Request-ID header on /health 504 response") + } + if _, ok := health504Headers["X-RateLimit-Limit"]; !ok { + t.Fatal("expected X-RateLimit-Limit header on /health 504 response") + } + if _, ok := health504Headers["X-RateLimit-Remaining"]; !ok { + t.Fatal("expected X-RateLimit-Remaining header on /health 504 response") + } + if _, ok := health504Headers["X-RateLimit-Reset"]; !ok { + t.Fatal("expected X-RateLimit-Reset header on /health 504 response") + } + health200 := health["responses"].(map[string]any)["200"].(map[string]any) + health200Headers := health200["headers"].(map[string]any) + if _, ok := health200Headers["X-Cache"]; !ok { + t.Fatal("expected X-Cache header on /health 200 response") + } // Verify system tag exists. tags := spec["tags"].([]any) @@ -74,48 +233,2391 @@ func TestSpecBuilder_Good_EmptyGroups(t *testing.T) { if !found { t.Fatal("expected system tag in spec") } + + components := spec["components"].(map[string]any) + schemas := components["schemas"].(map[string]any) + if _, ok := schemas["Response"]; !ok { + t.Fatal("expected Response component schema in spec") + } + securitySchemes := components["securitySchemes"].(map[string]any) + bearerAuth := securitySchemes["bearerAuth"].(map[string]any) + if bearerAuth["type"] != "http" { + t.Fatalf("expected bearerAuth.type=http, got %v", bearerAuth["type"]) + } + if bearerAuth["scheme"] != "bearer" { + t.Fatalf("expected bearerAuth.scheme=bearer, got %v", bearerAuth["scheme"]) + } + if _, ok := spec["security"]; ok { + t.Fatal("expected no global security requirement in the document") + } + if _, ok := spec["x-swagger-enabled"]; ok { + t.Fatal("expected no swagger enabled flag in the document when swagger is disabled") + } + if _, ok := spec["x-graphql-enabled"]; ok { + t.Fatal("expected no graphql enabled flag in the document when graphql is disabled") + } } -func TestSpecBuilder_Good_WithDescribableGroup(t *testing.T) { +func TestSpecBuilder_Good_NilReceiverIsZeroValueSafe(t *testing.T) { + var sb *api.SpecBuilder + + data, err := sb.Build(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + if spec["openapi"] != "3.1.0" { + t.Fatalf("expected openapi=3.1.0, got %v", spec["openapi"]) + } + + paths, ok := spec["paths"].(map[string]any) + if !ok { + t.Fatalf("expected paths object, got %T", spec["paths"]) + } + if _, ok := paths["/health"]; !ok { + t.Fatal("expected /health path to be present") + } +} + +func TestSpecBuilder_Good_CustomSecuritySchemesAreMerged(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + SecuritySchemes: map[string]any{ + "apiKeyAuth": map[string]any{ + "type": "apiKey", + "in": "header", + "name": "X-API-Key", + }, + }, + } + + data, err := sb.Build(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + components := spec["components"].(map[string]any) + schemes := components["securitySchemes"].(map[string]any) + + bearerAuth, ok := schemes["bearerAuth"].(map[string]any) + if !ok { + t.Fatal("expected default bearerAuth security scheme to remain present") + } + if bearerAuth["scheme"] != "bearer" { + t.Fatalf("expected bearerAuth scheme to stay bearer, got %v", bearerAuth["scheme"]) + } + + apiKeyAuth, ok := schemes["apiKeyAuth"].(map[string]any) + if !ok { + t.Fatal("expected custom apiKeyAuth security scheme to be merged") + } + if apiKeyAuth["type"] != "apiKey" { + t.Fatalf("expected apiKeyAuth.type=apiKey, got %v", apiKeyAuth["type"]) + } + if apiKeyAuth["in"] != "header" { + t.Fatalf("expected apiKeyAuth.in=header, got %v", apiKeyAuth["in"]) + } + if apiKeyAuth["name"] != "X-API-Key" { + t.Fatalf("expected apiKeyAuth.name=X-API-Key, got %v", apiKeyAuth["name"]) + } +} + +func TestSpecBuilder_Good_CommonResponseComponentsArePublished(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + } + + data, err := sb.Build(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + components := spec["components"].(map[string]any) + responses := components["responses"].(map[string]any) + for _, name := range []string{ + "BadRequest", + "Unauthorized", + "Forbidden", + "RateLimitExceeded", + "GatewayTimeout", + "InternalServerError", + "Gone", + } { + if _, ok := responses[name]; !ok { + t.Fatalf("expected %s response component in spec", name) + } + } +} + +func TestSpecBuilder_Good_NormalisesMetadataAtBuild(t *testing.T) { + sb := &api.SpecBuilder{ + Title: " Test API ", + Summary: " ", + Description: " Trimmed description ", + Version: " 1.2.3 ", + TermsOfService: " https://example.com/terms ", + ContactName: " API Support ", + ContactURL: " https://example.com/support ", + ContactEmail: " support@example.com ", + LicenseName: " EUPL-1.2 ", + LicenseURL: " https://eupl.eu/1.2/en/ ", + ExternalDocsURL: " https://example.com/docs ", + ExternalDocsDescription: " Developer guide ", + SwaggerPath: " /docs/ ", + } + + data, err := sb.Build(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + info := spec["info"].(map[string]any) + if info["title"] != "Test API" { + t.Fatalf("expected trimmed title, got %v", info["title"]) + } + if _, ok := info["summary"]; ok { + t.Fatal("expected blank summary to be omitted") + } + if info["description"] != "Trimmed description" { + t.Fatalf("expected trimmed description, got %v", info["description"]) + } + if info["version"] != "1.2.3" { + t.Fatalf("expected trimmed version, got %v", info["version"]) + } + if info["termsOfService"] != "https://example.com/terms" { + t.Fatalf("expected trimmed termsOfService, got %v", info["termsOfService"]) + } + contact := info["contact"].(map[string]any) + if contact["name"] != "API Support" { + t.Fatalf("expected trimmed contact name, got %v", contact["name"]) + } + if contact["url"] != "https://example.com/support" { + t.Fatalf("expected trimmed contact url, got %v", contact["url"]) + } + if contact["email"] != "support@example.com" { + t.Fatalf("expected trimmed contact email, got %v", contact["email"]) + } + license := info["license"].(map[string]any) + if license["name"] != "EUPL-1.2" { + t.Fatalf("expected trimmed licence name, got %v", license["name"]) + } + if license["url"] != "https://eupl.eu/1.2/en/" { + t.Fatalf("expected trimmed licence url, got %v", license["url"]) + } + externalDocs := spec["externalDocs"].(map[string]any) + if externalDocs["description"] != "Developer guide" { + t.Fatalf("expected trimmed external docs description, got %v", externalDocs["description"]) + } + if externalDocs["url"] != "https://example.com/docs" { + t.Fatalf("expected trimmed external docs url, got %v", externalDocs["url"]) + } + if got := spec["x-swagger-ui-path"]; got != "/docs" { + t.Fatalf("expected trimmed swagger path, got %v", got) + } +} + +func TestSpecBuilder_Good_SwaggerUIPathExtension(t *testing.T) { sb := &api.SpecBuilder{ Title: "Test", - Description: "Test API", + Description: "Swagger path test", + Version: "1.0.0", + SwaggerPath: "/docs/", + } + + data, err := sb.Build(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + if got := spec["x-swagger-ui-path"]; got != "/docs" { + t.Fatalf("expected x-swagger-ui-path=/docs, got %v", got) + } +} + +func TestSpecBuilder_Good_CacheAndI18nExtensions(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Description: "Runtime config test", + Version: "1.0.0", + CacheEnabled: true, + CacheTTL: (5 * time.Minute).String(), + CacheMaxEntries: 42, + CacheMaxBytes: 8192, + I18nDefaultLocale: "en-GB", + I18nSupportedLocales: []string{"en-GB", "fr"}, + } + + data, err := sb.Build(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + if got := spec["x-cache-enabled"]; got != true { + t.Fatalf("expected x-cache-enabled=true, got %v", got) + } + if got := spec["x-cache-ttl"]; got != "5m0s" { + t.Fatalf("expected x-cache-ttl=5m0s, got %v", got) + } + if got := spec["x-cache-max-entries"]; got != float64(42) { + t.Fatalf("expected x-cache-max-entries=42, got %v", got) + } + if got := spec["x-cache-max-bytes"]; got != float64(8192) { + t.Fatalf("expected x-cache-max-bytes=8192, got %v", got) + } + + if got := spec["x-i18n-default-locale"]; got != "en-GB" { + t.Fatalf("expected x-i18n-default-locale=en-GB, got %v", got) + } + locales, ok := spec["x-i18n-supported-locales"].([]any) + if !ok { + t.Fatalf("expected x-i18n-supported-locales array, got %T", spec["x-i18n-supported-locales"]) + } + if len(locales) != 2 || locales[0] != "en-GB" || locales[1] != "fr" { + t.Fatalf("expected supported locales [en-GB fr], got %v", locales) + } +} + +func TestSpecBuilder_Good_OmitsNonPositiveCacheTTLExtension(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Description: "Cache TTL test", + Version: "1.0.0", + CacheTTL: "0s", + } + + data, err := sb.Build(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + if _, ok := spec["x-cache-ttl"]; ok { + t.Fatal("expected non-positive cache TTL to be omitted from spec") + } +} + +func TestSpecBuilder_Good_GraphQLEndpoint(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Description: "GraphQL test", Version: "1.0.0", + GraphQLPath: "/graphql", + } + + data, err := sb.Build(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + tags := spec["tags"].([]any) + found := false + for _, tag := range tags { + tm := tag.(map[string]any) + if tm["name"] == "graphql" { + found = true + break + } + } + if !found { + t.Fatal("expected graphql tag in spec") + } + if _, ok := spec["x-graphql-playground"]; ok { + t.Fatal("expected x-graphql-playground to be omitted when playground is disabled") + } + + paths := spec["paths"].(map[string]any) + pathItem, ok := paths["/graphql"].(map[string]any) + if !ok { + t.Fatal("expected /graphql path in spec") + } + + getOp := pathItem["get"].(map[string]any) + if getOp["operationId"] != "get_graphql" { + t.Fatalf("expected GraphQL GET operationId to be get_graphql, got %v", getOp["operationId"]) + } + getParams := getOp["parameters"].([]any) + if len(getParams) != 3 { + t.Fatalf("expected 3 GraphQL GET query parameters, got %d", len(getParams)) + } + if getParams[0].(map[string]any)["name"] != "query" { + t.Fatalf("expected first GraphQL GET parameter to be query, got %v", getParams[0]) + } + if getParams[0].(map[string]any)["required"] != true { + t.Fatal("expected GraphQL GET query parameter to be required") + } + + postOp := pathItem["post"].(map[string]any) + if postOp["operationId"] != "post_graphql" { + t.Fatalf("expected GraphQL operationId to be post_graphql, got %v", postOp["operationId"]) + } + + responses := postOp["responses"].(map[string]any) + successHeaders := responses["200"].(map[string]any)["headers"].(map[string]any) + if _, ok := successHeaders["X-Cache"]; !ok { + t.Fatal("expected X-Cache header on GraphQL 200 response") + } + + requestBody := postOp["requestBody"].(map[string]any) + schema := requestBody["content"].(map[string]any)["application/json"].(map[string]any)["schema"].(map[string]any) + properties := schema["properties"].(map[string]any) + if _, ok := properties["query"]; !ok { + t.Fatal("expected GraphQL request schema to include query field") + } + if _, ok := properties["variables"]; !ok { + t.Fatal("expected GraphQL request schema to include variables field") + } + if _, ok := properties["operationName"]; !ok { + t.Fatal("expected GraphQL request schema to include operationName field") + } +} + +func TestSpecBuilder_Good_GraphQLPlaygroundEndpoint(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + GraphQLPath: "/graphql", + GraphQLPlayground: true, + GraphQLPlaygroundPath: "/graphql/playground", + } + + data, err := sb.Build(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + paths := spec["paths"].(map[string]any) + pathItem, ok := paths["/graphql/playground"].(map[string]any) + if !ok { + t.Fatal("expected /graphql/playground path in spec") + } + + getOp := pathItem["get"].(map[string]any) + if getOp["operationId"] != "get_graphql_playground" { + t.Fatalf("expected playground operationId to be get_graphql_playground, got %v", getOp["operationId"]) + } + if got := spec["x-graphql-playground-path"]; got != "/graphql/playground" { + t.Fatalf("expected x-graphql-playground-path=/graphql/playground, got %v", got) + } + + responses := getOp["responses"].(map[string]any) + success := responses["200"].(map[string]any) + content := success["content"].(map[string]any) + if _, ok := content["text/html"]; !ok { + t.Fatal("expected text/html content type for GraphQL playground response") + } +} + +func TestSpecBuilder_Good_GraphQLPlaygroundDefaultsToGraphQLPath(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + GraphQLPlayground: true, + } + + data, err := sb.Build(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + paths := spec["paths"].(map[string]any) + if _, ok := paths["/graphql"].(map[string]any); !ok { + t.Fatal("expected default /graphql path when playground is enabled") + } + if _, ok := paths["/graphql/playground"].(map[string]any); !ok { + t.Fatal("expected default /graphql/playground path when playground is enabled") + } +} + +func TestSpecBuilder_Good_GraphQLPlaygroundDefaultsToGraphQLTag(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + GraphQLPlayground: true, + } + + data, err := sb.Build(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + tags := spec["tags"].([]any) + found := false + for _, tag := range tags { + tm := tag.(map[string]any) + if tm["name"] == "graphql" { + found = true + break + } + } + if !found { + t.Fatal("expected graphql tag when playground enables the default GraphQL path") + } +} + +func TestSpecBuilder_Good_EnabledTransportsUseDefaultPaths(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + SwaggerEnabled: true, + GraphQLEnabled: true, + WSEnabled: true, + SSEEnabled: true, + } + + data, err := sb.Build(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + if got := spec["x-swagger-ui-path"]; got != "/swagger" { + t.Fatalf("expected default swagger path, got %v", got) + } + if got := spec["x-graphql-path"]; got != "/graphql" { + t.Fatalf("expected default graphql path, got %v", got) + } + if got := spec["x-ws-path"]; got != "/ws" { + t.Fatalf("expected default websocket path, got %v", got) + } + if got := spec["x-sse-path"]; got != "/events" { + t.Fatalf("expected default sse path, got %v", got) + } + + paths := spec["paths"].(map[string]any) + for _, path := range []string{"/graphql", "/ws", "/events"} { + if _, ok := paths[path].(map[string]any); !ok { + t.Fatalf("expected %s path in spec", path) + } + } + + tags := spec["tags"].([]any) + foundGraphQL := false + foundEvents := false + for _, tag := range tags { + tm := tag.(map[string]any) + switch tm["name"] { + case "graphql": + foundGraphQL = true + case "events": + foundEvents = true + } + } + if !foundGraphQL { + t.Fatal("expected graphql tag when GraphQL is enabled") + } + if !foundEvents { + t.Fatal("expected events tag when SSE is enabled") + } +} + +func TestSpecBuilder_Good_WebSocketEndpoint(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + WSPath: "/ws", + } + + data, err := sb.Build(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + tags := spec["tags"].([]any) + found := false + for _, tag := range tags { + tm := tag.(map[string]any) + if tm["name"] == "system" { + found = true + break + } + } + if !found { + t.Fatal("expected system tag in spec") + } + + paths := spec["paths"].(map[string]any) + pathItem, ok := paths["/ws"].(map[string]any) + if !ok { + t.Fatal("expected /ws path in spec") + } + + getOp := pathItem["get"].(map[string]any) + if getOp["operationId"] != "get_ws" { + t.Fatalf("expected WebSocket operationId to be get_ws, got %v", getOp["operationId"]) + } + if getOp["summary"] != "WebSocket connection" { + t.Fatalf("expected WebSocket summary, got %v", getOp["summary"]) + } + + responses := getOp["responses"].(map[string]any) + if _, ok := responses["101"]; !ok { + t.Fatal("expected 101 response on /ws") + } + if _, ok := responses["429"]; !ok { + t.Fatal("expected 429 response on /ws") + } +} + +func TestSpecBuilder_Good_ServerSentEventsEndpoint(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + SSEPath: "/events", + } + + data, err := sb.Build(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + tags := spec["tags"].([]any) + found := false + for _, tag := range tags { + tm := tag.(map[string]any) + if tm["name"] == "events" { + found = true + break + } + } + if !found { + t.Fatal("expected events tag in spec") + } + + paths := spec["paths"].(map[string]any) + pathItem, ok := paths["/events"].(map[string]any) + if !ok { + t.Fatal("expected /events path in spec") + } + + getOp := pathItem["get"].(map[string]any) + if getOp["operationId"] != "get_events" { + t.Fatalf("expected SSE operationId to be get_events, got %v", getOp["operationId"]) + } + + params := getOp["parameters"].([]any) + if len(params) != 1 { + t.Fatalf("expected one SSE query parameter, got %d", len(params)) + } + param := params[0].(map[string]any) + if param["name"] != "channel" || param["in"] != "query" { + t.Fatalf("expected channel query parameter, got %+v", param) + } + + responses := getOp["responses"].(map[string]any) + success := responses["200"].(map[string]any) + content := success["content"].(map[string]any) + if _, ok := content["text/event-stream"]; !ok { + t.Fatal("expected text/event-stream content type for SSE response") + } + headers := success["headers"].(map[string]any) + for _, name := range []string{"Cache-Control", "Connection", "X-Accel-Buffering"} { + if _, ok := headers[name]; !ok { + t.Fatalf("expected %s header in SSE response", name) + } + } +} + +func TestSpecBuilder_Good_InfoIncludesLicenseMetadata(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Description: "Licensed test API", + Version: "1.2.3", + LicenseName: "EUPL-1.2", + LicenseURL: "https://eupl.eu/1.2/en/", + } + + data, err := sb.Build(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + info := spec["info"].(map[string]any) + license, ok := info["license"].(map[string]any) + if !ok { + t.Fatal("expected license metadata in spec info") + } + if license["name"] != "EUPL-1.2" { + t.Fatalf("expected license name EUPL-1.2, got %v", license["name"]) + } + if license["url"] != "https://eupl.eu/1.2/en/" { + t.Fatalf("expected license url to be preserved, got %v", license["url"]) + } +} + +func TestSpecBuilder_Good_InfoIncludesSummary(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Summary: "Concise API overview", + Description: "Summary test API", + Version: "1.2.3", + } + + data, err := sb.Build(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + info := spec["info"].(map[string]any) + if info["summary"] != "Concise API overview" { + t.Fatalf("expected summary to be preserved, got %v", info["summary"]) + } +} + +func TestSpecBuilder_Good_InfoIncludesContactMetadata(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Description: "Contact test API", + Version: "1.2.3", + ContactName: "API Support", + ContactURL: "https://example.com/support", + ContactEmail: "support@example.com", + } + + data, err := sb.Build(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + info := spec["info"].(map[string]any) + contact, ok := info["contact"].(map[string]any) + if !ok { + t.Fatal("expected contact metadata in spec info") + } + if contact["name"] != "API Support" { + t.Fatalf("expected contact name API Support, got %v", contact["name"]) + } + if contact["url"] != "https://example.com/support" { + t.Fatalf("expected contact url to be preserved, got %v", contact["url"]) + } + if contact["email"] != "support@example.com" { + t.Fatalf("expected contact email to be preserved, got %v", contact["email"]) + } +} + +func TestSpecBuilder_Good_InfoIncludesTermsOfService(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Description: "Terms test API", + Version: "1.2.3", + TermsOfService: "https://example.com/terms", + } + + data, err := sb.Build(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + info := spec["info"].(map[string]any) + if info["termsOfService"] != "https://example.com/terms" { + t.Fatalf("expected termsOfService to be preserved, got %v", info["termsOfService"]) + } +} + +func TestSpecBuilder_Good_InfoIncludesExternalDocs(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Description: "External docs test API", + Version: "1.2.3", + ExternalDocsDescription: "Developer guide", + ExternalDocsURL: "https://example.com/docs", + } + + data, err := sb.Build(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + externalDocs, ok := spec["externalDocs"].(map[string]any) + if !ok { + t.Fatal("expected externalDocs metadata in spec") + } + if externalDocs["description"] != "Developer guide" { + t.Fatalf("expected externalDocs description to be preserved, got %v", externalDocs["description"]) + } + if externalDocs["url"] != "https://example.com/docs" { + t.Fatalf("expected externalDocs url to be preserved, got %v", externalDocs["url"]) + } +} + +func TestSpecBuilder_Good_WithDescribableGroup(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Description: "Test API", + Version: "1.0.0", + } + + group := &specStubGroup{ + name: "items", + basePath: "/api/items", + descs: []api.RouteDescription{ + { + Method: "GET", + Path: "/list", + Summary: "List items", + Tags: []string{"items"}, + Response: map[string]any{ + "type": "array", + "items": map[string]any{ + "type": "string", + }, + }, + }, + { + Method: "POST", + Path: "/create", + Summary: "Create item", + Description: "Creates a new item", + Tags: []string{"items"}, + RequestBody: map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + }, + }, + RequestExample: map[string]any{ + "name": "Widget", + }, + Response: map[string]any{ + "type": "object", + "properties": map[string]any{ + "id": map[string]any{"type": "integer"}, + }, + }, + ResponseExample: map[string]any{ + "id": 42, + }, + }, + }, + } + + data, err := sb.Build([]api.RouteGroup{group}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + paths := spec["paths"].(map[string]any) + + // Verify GET /api/items/list exists. + listPath, ok := paths["/api/items/list"] + if !ok { + t.Fatal("expected /api/items/list path in spec") + } + getOp := listPath.(map[string]any)["get"] + if getOp == nil { + t.Fatal("expected GET operation on /api/items/list") + } + if getOp.(map[string]any)["summary"] != "List items" { + t.Fatalf("expected summary='List items', got %v", getOp.(map[string]any)["summary"]) + } + if getOp.(map[string]any)["operationId"] != "get_api_items_list" { + t.Fatalf("expected operationId='get_api_items_list', got %v", getOp.(map[string]any)["operationId"]) + } + + // Verify POST /api/items/create exists with request body. + createPath, ok := paths["/api/items/create"] + if !ok { + t.Fatal("expected /api/items/create path in spec") + } + postOp := createPath.(map[string]any)["post"] + if postOp == nil { + t.Fatal("expected POST operation on /api/items/create") + } + if postOp.(map[string]any)["summary"] != "Create item" { + t.Fatalf("expected summary='Create item', got %v", postOp.(map[string]any)["summary"]) + } + if postOp.(map[string]any)["operationId"] != "post_api_items_create" { + t.Fatalf("expected operationId='post_api_items_create', got %v", postOp.(map[string]any)["operationId"]) + } + if postOp.(map[string]any)["requestBody"] == nil { + t.Fatal("expected requestBody on POST /api/items/create") + } + requestBody := postOp.(map[string]any)["requestBody"].(map[string]any) + appJSON := requestBody["content"].(map[string]any)["application/json"].(map[string]any) + if appJSON["example"].(map[string]any)["name"] != "Widget" { + t.Fatalf("expected request example to be preserved, got %v", appJSON["example"]) + } + + responses := postOp.(map[string]any)["responses"].(map[string]any) + created := responses["200"].(map[string]any) + createdJSON := created["content"].(map[string]any)["application/json"].(map[string]any) + if createdJSON["example"].(map[string]any)["id"] != float64(42) { + t.Fatalf("expected response example to be preserved, got %v", createdJSON["example"]) + } +} + +func TestSpecBuilder_Good_DescribeIterGroup(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + } + + group := &iterStubGroup{ + name: "iter", + basePath: "/api/iter", + descs: []api.RouteDescription{ + { + Method: "GET", + Path: "/status", + Summary: "Iter status", + Tags: []string{"iter"}, + Response: map[string]any{ + "type": "object", + }, + }, + }, + } + + data, err := sb.Build([]api.RouteGroup{group}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + op := spec["paths"].(map[string]any)["/api/iter/status"].(map[string]any)["get"].(map[string]any) + if op["summary"] != "Iter status" { + t.Fatalf("expected summary='Iter status', got %v", op["summary"]) + } + tags, ok := op["tags"].([]any) + if !ok || len(tags) != 1 || tags[0] != "iter" { + t.Fatalf("expected tags to be populated from DescribeIter, got %v", op["tags"]) + } +} + +func TestSpecBuilder_Good_DescribeIterSnapshotOnce(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + } + + group := &countingIterGroup{ + name: "counted", + basePath: "/api/count", + descs: []api.RouteDescription{ + { + Method: "GET", + Path: "/status", + Summary: "Counted status", + Tags: []string{"counted"}, + Response: map[string]any{ + "type": "object", + }, + }, + }, + } + + data, err := sb.Build([]api.RouteGroup{group}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + if group.describeCalls != 1 { + t.Fatalf("expected DescribeIter to be called once, got %d", group.describeCalls) + } + + op := spec["paths"].(map[string]any)["/api/count/status"].(map[string]any)["get"].(map[string]any) + if op["summary"] != "Counted status" { + t.Fatalf("expected summary='Counted status', got %v", op["summary"]) + } +} + +func TestSpecBuilder_Good_DescribeIterNilFallsBackToDescribe(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + } + + group := &iterNilFallbackGroup{ + name: "fallback-iter", + basePath: "/api/fallback-iter", + descs: []api.RouteDescription{ + { + Method: "GET", + Path: "/status", + Summary: "Fallback status", + Tags: []string{"fallback-iter"}, + Response: map[string]any{ + "type": "object", + }, + }, + }, + } + + data, err := sb.Build([]api.RouteGroup{group}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + op := spec["paths"].(map[string]any)["/api/fallback-iter/status"].(map[string]any)["get"].(map[string]any) + if op["summary"] != "Fallback status" { + t.Fatalf("expected summary='Fallback status', got %v", op["summary"]) + } +} + +func TestSpecBuilder_Good_GroupMetadataIsSnapshottedOnce(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + } + + group := &snapshottingGroup{ + descs: []api.RouteDescription{ + { + Method: "GET", + Path: "/status", + Summary: "Snapshot status", + Response: map[string]any{ + "type": "object", + }, + }, + }, + } + + data, err := sb.Build([]api.RouteGroup{group}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + paths := spec["paths"].(map[string]any) + if _, ok := paths["/alpha/status"]; !ok { + t.Fatalf("expected snapshotted path /alpha/status, got %v", paths) + } + if _, ok := paths["/beta/status"]; ok { + t.Fatal("did not expect mutated base path to leak into the spec") + } + + tags := spec["tags"].([]any) + foundAlpha := false + for _, tag := range tags { + tm := tag.(map[string]any) + if tm["name"] == "alpha" { + foundAlpha = true + break + } + if tm["name"] == "beta" { + t.Fatal("did not expect mutated group name to leak into the spec") + } + } + if !foundAlpha { + t.Fatal("expected snapshotted group name in spec tags") + } + + op := paths["/alpha/status"].(map[string]any)["get"].(map[string]any) + opTags, ok := op["tags"].([]any) + if !ok || len(opTags) != 1 || opTags[0] != "alpha" { + t.Fatalf("expected snapshotted operation tag alpha, got %v", op["tags"]) + } + + if group.nameCalls != 1 { + t.Fatalf("expected Name to be called once, got %d", group.nameCalls) + } + if group.basePathCalls != 1 { + t.Fatalf("expected BasePath to be called once, got %d", group.basePathCalls) + } +} + +func TestSpecBuilder_Good_DeepClonesRouteMetadata(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + } + + group := &mutatingIterGroup{ + name: "alpha", + basePath: "/api", + descs: []api.RouteDescription{ + { + Method: "POST", + Path: "/items", + Summary: "Create item", + Tags: []string{"items"}, + Parameters: []api.ParameterDescription{ + { + Name: "id", + In: "path", + Schema: map[string]any{ + "type": "string", + }, + }, + }, + RequestBody: map[string]any{ + "type": "object", + }, + Response: map[string]any{ + "type": "object", + }, + ResponseHeaders: map[string]string{ + "X-Test": "Original header", + }, + }, + }, + } + + data, err := sb.Build([]api.RouteGroup{group}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + op := spec["paths"].(map[string]any)["/api/items"].(map[string]any)["post"].(map[string]any) + requestSchema := op["requestBody"].(map[string]any)["content"].(map[string]any)["application/json"].(map[string]any)["schema"].(map[string]any) + if _, ok := requestSchema["mutated"]; ok { + t.Fatal("did not expect request body mutation to leak into the spec") + } + + responses := op["responses"].(map[string]any) + resp201 := responses["200"].(map[string]any) + appJSON := resp201["content"].(map[string]any)["application/json"].(map[string]any) + responseSchema := appJSON["schema"].(map[string]any)["properties"].(map[string]any)["data"].(map[string]any) + if _, ok := responseSchema["mutated"]; ok { + t.Fatal("did not expect response mutation to leak into the spec") + } + + headers := resp201["headers"].(map[string]any) + if _, ok := headers["X-Mutated"]; ok { + t.Fatal("did not expect response header mutation to leak into the spec") + } + + params := op["parameters"].([]any) + pathParam := params[0].(map[string]any) + schema := pathParam["schema"].(map[string]any) + if _, ok := schema["mutated"]; ok { + t.Fatal("did not expect parameter schema mutation to leak into the spec") + } +} + +func TestSpecBuilder_Good_SecuredResponses(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + } + + group := &specStubGroup{ + name: "secure", + basePath: "/api", + descs: []api.RouteDescription{ + { + Method: "GET", + Path: "/private", + Summary: "Private endpoint", + Response: map[string]any{ + "type": "object", + }, + }, + }, + } + + data, err := sb.Build([]api.RouteGroup{group}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + responses := spec["paths"].(map[string]any)["/api/private"].(map[string]any)["get"].(map[string]any)["responses"].(map[string]any) + if _, ok := responses["401"]; !ok { + t.Fatal("expected 401 response in secured operation") + } + if _, ok := responses["403"]; !ok { + t.Fatal("expected 403 response in secured operation") + } + if _, ok := responses["429"]; !ok { + t.Fatal("expected 429 response in secured operation") + } + if _, ok := responses["504"]; !ok { + t.Fatal("expected 504 response in secured operation") + } + if _, ok := responses["500"]; !ok { + t.Fatal("expected 500 response in secured operation") + } + rateLimit429 := responses["429"].(map[string]any) + headers := rateLimit429["headers"].(map[string]any) + if _, ok := headers["Retry-After"]; !ok { + t.Fatal("expected Retry-After header in secured operation 429 response") + } + if _, ok := headers["X-Request-ID"]; !ok { + t.Fatal("expected X-Request-ID header in secured operation 429 response") + } + if _, ok := headers["X-RateLimit-Limit"]; !ok { + t.Fatal("expected X-RateLimit-Limit header in secured operation 429 response") + } + if _, ok := headers["X-RateLimit-Remaining"]; !ok { + t.Fatal("expected X-RateLimit-Remaining header in secured operation 429 response") + } + if _, ok := headers["X-RateLimit-Reset"]; !ok { + t.Fatal("expected X-RateLimit-Reset header in secured operation 429 response") + } + for _, code := range []string{"400", "401", "403", "504", "500"} { + resp := responses[code].(map[string]any) + respHeaders := resp["headers"].(map[string]any) + if _, ok := respHeaders["X-Request-ID"]; !ok { + t.Fatalf("expected X-Request-ID header in secured operation %s response", code) + } + if _, ok := respHeaders["X-RateLimit-Limit"]; !ok { + t.Fatalf("expected X-RateLimit-Limit header in secured operation %s response", code) + } + if _, ok := respHeaders["X-RateLimit-Remaining"]; !ok { + t.Fatalf("expected X-RateLimit-Remaining header in secured operation %s response", code) + } + if _, ok := respHeaders["X-RateLimit-Reset"]; !ok { + t.Fatalf("expected X-RateLimit-Reset header in secured operation %s response", code) + } + } +} + +func TestSpecBuilder_Good_CustomSuccessStatusCode(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + } + + group := &specStubGroup{ + name: "items", + basePath: "/api", + descs: []api.RouteDescription{ + { + Method: "POST", + Path: "/items", + Summary: "Create item", + StatusCode: http.StatusCreated, + Response: map[string]any{ + "type": "object", + }, + }, + }, + } + + data, err := sb.Build([]api.RouteGroup{group}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + responses := spec["paths"].(map[string]any)["/api/items"].(map[string]any)["post"].(map[string]any)["responses"].(map[string]any) + if _, ok := responses["201"]; !ok { + t.Fatal("expected 201 response for created operation") + } + if _, ok := responses["200"]; ok { + t.Fatal("expected 200 response to be omitted when a custom success status is declared") + } + + created := responses["201"].(map[string]any) + if created["description"] != "Created" { + t.Fatalf("expected created description, got %v", created["description"]) + } + if created["content"] == nil { + t.Fatal("expected content for 201 response") + } +} + +func TestSpecBuilder_Good_NoContentSuccessStatusCode(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + } + + group := &specStubGroup{ + name: "items", + basePath: "/api", + descs: []api.RouteDescription{ + { + Method: "DELETE", + Path: "/items/{id}", + Summary: "Delete item", + StatusCode: http.StatusNoContent, + Response: map[string]any{ + "type": "object", + }, + }, + }, + } + + data, err := sb.Build([]api.RouteGroup{group}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + responses := spec["paths"].(map[string]any)["/api/items/{id}"].(map[string]any)["delete"].(map[string]any)["responses"].(map[string]any) + resp204 := responses["204"].(map[string]any) + if resp204["description"] != "No content" { + t.Fatalf("expected no-content description, got %v", resp204["description"]) + } + if _, ok := resp204["content"]; ok { + t.Fatal("expected no content block for 204 response") + } +} + +func TestSpecBuilder_Good_RouteSecurityOverrides(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + } + + group := &specStubGroup{ + name: "security", + basePath: "/api", + descs: []api.RouteDescription{ + { + Method: "GET", + Path: "/public", + Summary: "Public endpoint", + Security: []map[string][]string{}, + Response: map[string]any{ + "type": "object", + }, + }, + { + Method: "GET", + Path: "/scoped", + Summary: "Scoped endpoint", + Security: []map[string][]string{ + { + "bearerAuth": []string{}, + }, + { + "oauth2": []string{"read:items"}, + }, + }, + Response: map[string]any{ + "type": "object", + }, + }, + }, + } + + data, err := sb.Build([]api.RouteGroup{group}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + paths := spec["paths"].(map[string]any) + + publicOp := paths["/api/public"].(map[string]any)["get"].(map[string]any) + publicSecurity, ok := publicOp["security"].([]any) + if !ok { + t.Fatalf("expected public security array, got %T", publicOp["security"]) + } + if len(publicSecurity) != 0 { + t.Fatalf("expected public route to have empty security requirement, got %v", publicSecurity) + } + publicResponses := publicOp["responses"].(map[string]any) + if _, ok := publicResponses["401"]; ok { + t.Fatal("expected public route to omit 401 response documentation") + } + if _, ok := publicResponses["403"]; ok { + t.Fatal("expected public route to omit 403 response documentation") + } + + scopedOp := paths["/api/scoped"].(map[string]any)["get"].(map[string]any) + scopedSecurity, ok := scopedOp["security"].([]any) + if !ok { + t.Fatalf("expected scoped security array, got %T", scopedOp["security"]) + } + if len(scopedSecurity) != 2 { + t.Fatalf("expected 2 security requirements, got %d", len(scopedSecurity)) + } + firstReq := scopedSecurity[0].(map[string]any) + if _, ok := firstReq["bearerAuth"]; !ok { + t.Fatalf("expected bearerAuth requirement, got %v", firstReq) + } + secondReq := scopedSecurity[1].(map[string]any) + if scopes, ok := secondReq["oauth2"].([]any); !ok || len(scopes) != 1 || scopes[0] != "read:items" { + t.Fatalf("expected oauth2 read:items requirement, got %v", secondReq["oauth2"]) + } +} + +func TestSpecBuilder_Good_AuthentikPublicPathsMakeMatchingOperationsPublic(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + AuthentikPublicPaths: []string{"/api/public"}, + } + + group := &specStubGroup{ + name: "security", + basePath: "/api", + descs: []api.RouteDescription{ + { + Method: "GET", + Path: "/public", + Summary: "Public endpoint", + Security: []map[string][]string{{"bearerAuth": []string{}}}, + Response: map[string]any{ + "type": "object", + }, + }, + }, + } + + data, err := sb.Build([]api.RouteGroup{group}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + op := spec["paths"].(map[string]any)["/api/public"].(map[string]any)["get"].(map[string]any) + security, ok := op["security"].([]any) + if !ok { + t.Fatalf("expected public route security array, got %T", op["security"]) + } + if len(security) != 0 { + t.Fatalf("expected public route to be documented without auth, got %v", security) + } + + responses := op["responses"].(map[string]any) + if _, ok := responses["401"]; ok { + t.Fatal("expected public route to omit 401 response documentation") + } + if _, ok := responses["403"]; ok { + t.Fatal("expected public route to omit 403 response documentation") + } + + paths := spec["x-authentik-public-paths"].([]any) + if len(paths) == 0 || paths[0] != "/health" { + t.Fatalf("expected public path extension to include /health first, got %v", paths) + } +} + +func TestSpecBuilder_Good_AuthentikPublicPathsMakeBuiltInEndpointsPublic(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + GraphQLEnabled: true, + GraphQLPath: "/graphql", + AuthentikPublicPaths: []string{"/graphql"}, + } + + data, err := sb.Build(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + pathItem := spec["paths"].(map[string]any)["/graphql"].(map[string]any) + for _, method := range []string{"get", "post"} { + op := pathItem[method].(map[string]any) + security, ok := op["security"].([]any) + if !ok { + t.Fatalf("expected %s security array, got %T", method, op["security"]) + } + if len(security) != 0 { + t.Fatalf("expected %s operation to be documented without auth, got %v", method, security) + } + + responses := op["responses"].(map[string]any) + if _, ok := responses["401"]; ok { + t.Fatalf("expected %s operation to omit 401 response documentation", method) + } + if _, ok := responses["403"]; ok { + t.Fatalf("expected %s operation to omit 403 response documentation", method) + } + } +} + +func TestSpecBuilder_Good_EnvelopeWrapping(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + } + + group := &specStubGroup{ + name: "data", + basePath: "/data", + descs: []api.RouteDescription{ + { + Method: "GET", + Path: "/fetch", + Summary: "Fetch data", + Tags: []string{"data"}, + Response: map[string]any{ + "type": "object", + "properties": map[string]any{ + "value": map[string]any{"type": "string"}, + }, + }, + }, + }, + } + + data, err := sb.Build([]api.RouteGroup{group}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + paths := spec["paths"].(map[string]any) + fetchPath := paths["/data/fetch"].(map[string]any) + getOp := fetchPath["get"].(map[string]any) + responses := getOp["responses"].(map[string]any) + resp200 := responses["200"].(map[string]any) + headers := resp200["headers"].(map[string]any) + if _, ok := headers["X-Request-ID"]; !ok { + t.Fatal("expected X-Request-ID header on 200 response") + } + if _, ok := headers["X-RateLimit-Limit"]; !ok { + t.Fatal("expected X-RateLimit-Limit header on 200 response") + } + if _, ok := headers["X-RateLimit-Remaining"]; !ok { + t.Fatal("expected X-RateLimit-Remaining header on 200 response") + } + if _, ok := headers["X-RateLimit-Reset"]; !ok { + t.Fatal("expected X-RateLimit-Reset header on 200 response") + } + if _, ok := headers["X-Cache"]; !ok { + t.Fatal("expected X-Cache header on 200 response") + } + content := resp200["content"].(map[string]any) + appJSON := content["application/json"].(map[string]any) + schema := appJSON["schema"].(map[string]any) + if getOp["operationId"] != "get_data_fetch" { + t.Fatalf("expected operationId='get_data_fetch', got %v", getOp["operationId"]) + } + + // Verify envelope structure. + if schema["type"] != "object" { + t.Fatalf("expected schema type=object, got %v", schema["type"]) + } + + properties := schema["properties"].(map[string]any) + + // Verify success field. + success := properties["success"].(map[string]any) + if success["type"] != "boolean" { + t.Fatalf("expected success.type=boolean, got %v", success["type"]) + } + + // Verify data field contains the original response schema. + dataField := properties["data"].(map[string]any) + if dataField["type"] != "object" { + t.Fatalf("expected data.type=object, got %v", dataField["type"]) + } + dataProps := dataField["properties"].(map[string]any) + if dataProps["value"] == nil { + t.Fatal("expected data.properties.value to exist") + } + + // Verify required contains "success". + required := schema["required"].([]any) + foundSuccess := false + for _, r := range required { + if r == "success" { + foundSuccess = true + break + } + } + if !foundSuccess { + t.Fatal("expected 'success' in required array") + } +} + +func TestSpecBuilder_Good_OperationIDPreservesPathParams(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + } + + group := &specStubGroup{ + name: "users", + basePath: "/api", + descs: []api.RouteDescription{ + { + Method: "GET", + Path: "/users/{id}", + Summary: "Get user by id", + Tags: []string{"users"}, + Response: map[string]any{ + "type": "object", + }, + }, + { + Method: "GET", + Path: "/users/{name}", + Summary: "Get user by name", + Tags: []string{"users"}, + Response: map[string]any{ + "type": "object", + }, + }, + }, + } + + data, err := sb.Build([]api.RouteGroup{group}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + paths := spec["paths"].(map[string]any) + byID := paths["/api/users/{id}"].(map[string]any)["get"].(map[string]any) + byName := paths["/api/users/{name}"].(map[string]any)["get"].(map[string]any) + + if byID["operationId"] != "get_api_users_id" { + t.Fatalf("expected operationId='get_api_users_id', got %v", byID["operationId"]) + } + if byName["operationId"] != "get_api_users_name" { + t.Fatalf("expected operationId='get_api_users_name', got %v", byName["operationId"]) + } + if byID["operationId"] == byName["operationId"] { + t.Fatal("expected unique operationId values for distinct path parameters") + } +} + +func TestSpecBuilder_Good_RequestBodyOnDelete(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + } + + group := &specStubGroup{ + name: "resources", + basePath: "/api", + descs: []api.RouteDescription{ + { + Method: "DELETE", + Path: "/resources/{id}", + Summary: "Delete resource", + Tags: []string{"resources"}, + RequestBody: map[string]any{ + "type": "object", + "properties": map[string]any{ + "reason": map[string]any{"type": "string"}, + }, + }, + Response: map[string]any{ + "type": "object", + }, + }, + }, + } + + data, err := sb.Build([]api.RouteGroup{group}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + paths := spec["paths"].(map[string]any) + deleteOp := paths["/api/resources/{id}"].(map[string]any)["delete"].(map[string]any) + if deleteOp["requestBody"] == nil { + t.Fatal("expected requestBody on DELETE /api/resources/{id}") + } +} + +func TestSpecBuilder_Good_RequestBodyOnHead(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + } + + group := &specStubGroup{ + name: "resources", + basePath: "/api", + descs: []api.RouteDescription{ + { + Method: "HEAD", + Path: "/resources/{id}", + Summary: "Check resource", + Tags: []string{"resources"}, + RequestBody: map[string]any{ + "type": "object", + "properties": map[string]any{ + "include": map[string]any{"type": "string"}, + }, + }, + Response: map[string]any{ + "type": "object", + }, + }, + }, + } + + data, err := sb.Build([]api.RouteGroup{group}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + paths := spec["paths"].(map[string]any) + headOp := paths["/api/resources/{id}"].(map[string]any)["head"].(map[string]any) + if headOp["requestBody"] == nil { + t.Fatal("expected requestBody on HEAD /api/resources/{id}") + } +} + +func TestSpecBuilder_Good_RequestExampleWithoutSchema(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + } + + group := &specStubGroup{ + name: "resources", + basePath: "/api", + descs: []api.RouteDescription{ + { + Method: "POST", + Path: "/resources", + Summary: "Create resource", + Tags: []string{"resources"}, + RequestExample: map[string]any{ + "name": "Example resource", + }, + Response: map[string]any{ + "type": "object", + }, + }, + }, + } + + data, err := sb.Build([]api.RouteGroup{group}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + postOp := spec["paths"].(map[string]any)["/api/resources"].(map[string]any)["post"].(map[string]any) + requestBody := postOp["requestBody"].(map[string]any) + appJSON := requestBody["content"].(map[string]any)["application/json"].(map[string]any) + + if appJSON["example"].(map[string]any)["name"] != "Example resource" { + t.Fatalf("expected request example to be preserved, got %v", appJSON["example"]) + } + + schema := appJSON["schema"].(map[string]any) + if len(schema) != 0 { + t.Fatalf("expected example-only request body to use an empty schema, got %v", schema) + } +} + +func TestSpecBuilder_Good_ResponseExampleWithoutSchema(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + } + + group := &specStubGroup{ + name: "resources", + basePath: "/api", + descs: []api.RouteDescription{ + { + Method: "GET", + Path: "/resources/{id}", + Summary: "Fetch resource", + Tags: []string{"resources"}, + ResponseExample: map[string]any{ + "name": "Example resource", + }, + }, + }, + } + + data, err := sb.Build([]api.RouteGroup{group}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + getOp := spec["paths"].(map[string]any)["/api/resources/{id}"].(map[string]any)["get"].(map[string]any) + responses := getOp["responses"].(map[string]any) + resp200 := responses["200"].(map[string]any) + appJSON := resp200["content"].(map[string]any)["application/json"].(map[string]any) + + if appJSON["example"].(map[string]any)["name"] != "Example resource" { + t.Fatalf("expected response example to be preserved, got %v", appJSON["example"]) + } + + schema := appJSON["schema"].(map[string]any) + properties, ok := schema["properties"].(map[string]any) + if !ok { + t.Fatalf("expected envelope schema properties, got %v", schema) + } + if _, ok := properties["data"]; !ok { + t.Fatal("expected example-only response to expose an empty data schema") + } +} + +func TestSpecBuilder_Good_ResponseHeaders(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + } + + group := &specStubGroup{ + name: "downloads", + basePath: "/api", + descs: []api.RouteDescription{ + { + Method: "GET", + Path: "/exports/{id}", + Summary: "Download export", + ResponseHeaders: map[string]string{ + "Content-Disposition": "Download filename suggested by the server", + "X-Export-ID": "Identifier for the generated export", + }, + Response: map[string]any{ + "type": "object", + }, + }, + }, + } + + data, err := sb.Build([]api.RouteGroup{group}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + responses := spec["paths"].(map[string]any)["/api/exports/{id}"].(map[string]any)["get"].(map[string]any)["responses"].(map[string]any) + resp200 := responses["200"].(map[string]any) + headers, ok := resp200["headers"].(map[string]any) + if !ok { + t.Fatalf("expected headers map, got %T", resp200["headers"]) + } + + header, ok := headers["Content-Disposition"].(map[string]any) + if !ok { + t.Fatal("expected Content-Disposition response header to be documented") + } + if header["description"] != "Download filename suggested by the server" { + t.Fatalf("expected header description to be preserved, got %v", header["description"]) + } + schema := header["schema"].(map[string]any) + if schema["type"] != "string" { + t.Fatalf("expected response header schema type string, got %v", schema["type"]) + } + + errorResp := responses["500"].(map[string]any) + errorHeaders, ok := errorResp["headers"].(map[string]any) + if !ok { + t.Fatalf("expected 500 headers map, got %T", errorResp["headers"]) + } + if _, ok := errorHeaders["Content-Disposition"]; !ok { + t.Fatal("expected route-specific headers on error responses too") + } + if _, ok := errorHeaders["X-Export-ID"]; !ok { + t.Fatal("expected route-specific headers on error responses too") + } +} + +func TestSpecBuilder_Good_PathParameters(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + } + + group := &specStubGroup{ + name: "users", + basePath: "/api", + descs: []api.RouteDescription{ + { + Method: "GET", + Path: "/users/{id}/{slug}", + Summary: "Get user", + Response: map[string]any{ + "type": "object", + }, + }, + }, + } + + data, err := sb.Build([]api.RouteGroup{group}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + op := spec["paths"].(map[string]any)["/api/users/{id}/{slug}"].(map[string]any)["get"].(map[string]any) + params, ok := op["parameters"].([]any) + if !ok { + t.Fatalf("expected parameters array, got %T", op["parameters"]) + } + if len(params) != 2 { + t.Fatalf("expected 2 path parameters, got %d", len(params)) + } + + first := params[0].(map[string]any) + if first["name"] != "id" { + t.Fatalf("expected first parameter name=id, got %v", first["name"]) + } + if first["in"] != "path" { + t.Fatalf("expected first parameter in=path, got %v", first["in"]) + } + if required, ok := first["required"].(bool); !ok || !required { + t.Fatalf("expected first parameter to be required, got %v", first["required"]) + } + + second := params[1].(map[string]any) + if second["name"] != "slug" { + t.Fatalf("expected second parameter name=slug, got %v", second["name"]) + } +} + +func TestSpecBuilder_Good_PathNormalisation(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + } + + group := &specStubGroup{ + name: "users", + basePath: "/api/", + descs: []api.RouteDescription{ + { + Method: "GET", + Path: "users/{id}", + Summary: "Get user", + Response: map[string]any{ + "type": "object", + }, + }, + }, + } + + data, err := sb.Build([]api.RouteGroup{group}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + paths := spec["paths"].(map[string]any) + if _, ok := paths["/api/users/{id}"]; !ok { + t.Fatalf("expected normalised path /api/users/{id}, got %v", paths) + } +} + +func TestSpecBuilder_Good_GinPathParameters(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + } + + group := &specStubGroup{ + name: "users", + basePath: "/api/", + descs: []api.RouteDescription{ + { + Method: "GET", + Path: "users/:id", + Summary: "Get user", + Response: map[string]any{ + "type": "object", + }, + }, + { + Method: "GET", + Path: "files/*path", + Summary: "Get file", + Response: map[string]any{ + "type": "object", + }, + }, + }, + } + + data, err := sb.Build([]api.RouteGroup{group}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + paths := spec["paths"].(map[string]any) + + userOp := paths["/api/users/{id}"].(map[string]any)["get"].(map[string]any) + userParams := userOp["parameters"].([]any) + if len(userParams) != 1 { + t.Fatalf("expected 1 parameter for gin path, got %d", len(userParams)) + } + if userParams[0].(map[string]any)["name"] != "id" { + t.Fatalf("expected gin path parameter name=id, got %v", userParams[0]) + } + + fileOp := paths["/api/files/{path}"].(map[string]any)["get"].(map[string]any) + fileParams := fileOp["parameters"].([]any) + if len(fileParams) != 1 { + t.Fatalf("expected 1 parameter for wildcard path, got %d", len(fileParams)) + } + if fileParams[0].(map[string]any)["name"] != "path" { + t.Fatalf("expected wildcard parameter name=path, got %v", fileParams[0]) + } +} + +func TestSpecBuilder_Good_ExplicitParameters(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + } + + group := &specStubGroup{ + name: "users", + basePath: "/api", + descs: []api.RouteDescription{ + { + Method: "GET", + Path: "/users/{id}", + Summary: "Get user", + Parameters: []api.ParameterDescription{ + { + Name: "id", + In: "path", + Description: "User identifier", + Schema: map[string]any{ + "type": "string", + }, + }, + { + Name: "verbose", + In: "query", + Description: "Include verbose details", + Schema: map[string]any{ + "type": "boolean", + }, + }, + }, + Response: map[string]any{ + "type": "object", + }, + }, + }, + } + + data, err := sb.Build([]api.RouteGroup{group}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + op := spec["paths"].(map[string]any)["/api/users/{id}"].(map[string]any)["get"].(map[string]any) + params, ok := op["parameters"].([]any) + if !ok { + t.Fatalf("expected parameters array, got %T", op["parameters"]) + } + if len(params) != 2 { + t.Fatalf("expected 2 parameters, got %d", len(params)) + } + + pathParam := params[0].(map[string]any) + if pathParam["name"] != "id" { + t.Fatalf("expected path parameter name=id, got %v", pathParam["name"]) + } + if pathParam["in"] != "path" { + t.Fatalf("expected path parameter in=path, got %v", pathParam["in"]) + } + if pathParam["description"] != "User identifier" { + t.Fatalf("expected merged path parameter description, got %v", pathParam["description"]) + } + + queryParam := params[1].(map[string]any) + if queryParam["name"] != "verbose" { + t.Fatalf("expected query parameter name=verbose, got %v", queryParam["name"]) + } + if queryParam["in"] != "query" { + t.Fatalf("expected query parameter in=query, got %v", queryParam["in"]) + } + if required, ok := queryParam["required"].(bool); !ok || required { + t.Fatalf("expected query parameter to be optional, got %v", queryParam["required"]) + } +} + +func TestSpecBuilder_Good_NonDescribableGroup(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + } + + data, err := sb.Build([]api.RouteGroup{plainStubGroup{}}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + // Verify plainStubGroup appears in tags. + tags := spec["tags"].([]any) + foundPlain := false + for _, tag := range tags { + tm := tag.(map[string]any) + if tm["name"] == "plain" { + foundPlain = true + break + } + } + if !foundPlain { + t.Fatal("expected 'plain' tag in spec for non-describable group") + } + + // Verify only /health exists in paths (plain group adds no paths). + paths := spec["paths"].(map[string]any) + if len(paths) != 1 { + t.Fatalf("expected 1 path (/health only), got %d", len(paths)) + } + if _, ok := paths["/health"]; !ok { + t.Fatal("expected /health path in spec") + } + health := paths["/health"].(map[string]any)["get"].(map[string]any) + if health["operationId"] != "get_health" { + t.Fatalf("expected operationId='get_health', got %v", health["operationId"]) + } + if security := health["security"]; security == nil { + t.Fatal("expected explicit public security override on /health") + } + if len(health["security"].([]any)) != 0 { + t.Fatalf("expected /health security to be empty, got %v", health["security"]) + } +} + +func TestSpecBuilder_Good_EmptyDescribableGroupStillAddsTag(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + } + + group := &specStubGroup{ + name: "empty", + basePath: "/api/empty", + descs: nil, + } + + data, err := sb.Build([]api.RouteGroup{group}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + tags := spec["tags"].([]any) + foundEmpty := false + for _, tag := range tags { + tm := tag.(map[string]any) + if tm["name"] == "empty" { + foundEmpty = true + break + } + } + if !foundEmpty { + t.Fatal("expected empty describable group to appear in spec tags") + } + + paths := spec["paths"].(map[string]any) + if len(paths) != 1 { + t.Fatalf("expected only /health path, got %d paths", len(paths)) + } + if _, ok := paths["/health"]; !ok { + t.Fatal("expected /health path in spec") + } +} + +func TestSpecBuilder_Good_DefaultTagsFromGroupName(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + } + + group := &specStubGroup{ + name: "fallback", + basePath: "/api/fallback", + descs: []api.RouteDescription{ + { + Method: "GET", + Path: "/status", + Summary: "Check status", + Response: map[string]any{ + "type": "object", + }, + }, + }, + } + + data, err := sb.Build([]api.RouteGroup{group}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + operation := spec["paths"].(map[string]any)["/api/fallback/status"].(map[string]any)["get"].(map[string]any) + tags, ok := operation["tags"].([]any) + if !ok { + t.Fatalf("expected tags array, got %T", operation["tags"]) + } + if len(tags) != 1 || tags[0] != "fallback" { + t.Fatalf("expected fallback tag from group name, got %v", operation["tags"]) + } +} + +func TestSpecBuilder_Good_TagsAreSortedDeterministically(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + } + + group := &specStubGroup{ + name: "gamma", + basePath: "/api/gamma", + descs: []api.RouteDescription{ + { + Method: "GET", + Path: "/status", + Summary: "Check status", + Tags: []string{"zeta", "alpha", "beta"}, + Response: map[string]any{ + "type": "object", + }, + }, + }, + } + + data, err := sb.Build([]api.RouteGroup{group}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + tags, ok := spec["tags"].([]any) + if !ok { + t.Fatalf("expected tags array, got %T", spec["tags"]) + } + + names := make([]string, 0, len(tags)) + for _, raw := range tags { + tag := raw.(map[string]any) + name, _ := tag["name"].(string) + names = append(names, name) + } + + expected := []string{"system", "alpha", "beta", "gamma", "zeta"} + if len(names) != len(expected) { + t.Fatalf("expected %d tags, got %d: %v", len(expected), len(names), names) + } + for i := range expected { + if names[i] != expected[i] { + t.Fatalf("expected tag order %v, got %v", expected, names) + } + } +} + +func TestSpecBuilder_Good_DeprecatedOperation(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", } group := &specStubGroup{ - name: "items", - basePath: "/api/items", + name: "legacy", + basePath: "/api/legacy", descs: []api.RouteDescription{ { - Method: "GET", - Path: "/list", - Summary: "List items", - Tags: []string{"items"}, - Response: map[string]any{ - "type": "array", - "items": map[string]any{ - "type": "string", - }, - }, - }, - { - Method: "POST", - Path: "/create", - Summary: "Create item", - Description: "Creates a new item", - Tags: []string{"items"}, - RequestBody: map[string]any{ - "type": "object", - "properties": map[string]any{ - "name": map[string]any{"type": "string"}, - }, - }, + Method: "GET", + Path: "/status", + Summary: "Check legacy status", + Deprecated: true, + SunsetDate: "2025-06-01", + Replacement: "/api/v2/status", Response: map[string]any{ "type": "object", - "properties": map[string]any{ - "id": map[string]any{"type": "integer"}, - }, }, }, }, @@ -131,58 +2633,72 @@ func TestSpecBuilder_Good_WithDescribableGroup(t *testing.T) { t.Fatalf("invalid JSON: %v", err) } - paths := spec["paths"].(map[string]any) - - // Verify GET /api/items/list exists. - listPath, ok := paths["/api/items/list"] + op := spec["paths"].(map[string]any)["/api/legacy/status"].(map[string]any)["get"].(map[string]any) + deprecated, ok := op["deprecated"].(bool) if !ok { - t.Fatal("expected /api/items/list path in spec") + t.Fatalf("expected deprecated boolean, got %T", op["deprecated"]) } - getOp := listPath.(map[string]any)["get"] - if getOp == nil { - t.Fatal("expected GET operation on /api/items/list") + if !deprecated { + t.Fatal("expected deprecated operation to be marked true") } - if getOp.(map[string]any)["summary"] != "List items" { - t.Fatalf("expected summary='List items', got %v", getOp.(map[string]any)["summary"]) + + responses := op["responses"].(map[string]any) + success := responses["200"].(map[string]any) + headers := success["headers"].(map[string]any) + for _, name := range []string{"Deprecation", "Sunset", "Link", "X-API-Warn"} { + if _, ok := headers[name]; !ok { + t.Fatalf("expected deprecation header %q in operation response headers", name) + } } - // Verify POST /api/items/create exists with request body. - createPath, ok := paths["/api/items/create"] + gone, ok := responses["410"].(map[string]any) if !ok { - t.Fatal("expected /api/items/create path in spec") + t.Fatal("expected 410 Gone response for sunsetted operation") } - postOp := createPath.(map[string]any)["post"] - if postOp == nil { - t.Fatal("expected POST operation on /api/items/create") + if got := gone["description"]; got != "Gone" { + t.Fatalf("expected 410 response description Gone, got %v", got) } - if postOp.(map[string]any)["summary"] != "Create item" { - t.Fatalf("expected summary='Create item', got %v", postOp.(map[string]any)["summary"]) + goneHeaders, ok := gone["headers"].(map[string]any) + if !ok { + t.Fatalf("expected 410 response headers map, got %T", gone["headers"]) } - if postOp.(map[string]any)["requestBody"] == nil { - t.Fatal("expected requestBody on POST /api/items/create") + for _, name := range []string{"Deprecation", "Sunset", "Link", "X-API-Warn"} { + if _, ok := goneHeaders[name]; !ok { + t.Fatalf("expected deprecation header %q in 410 response headers", name) + } + } + + components := spec["components"].(map[string]any) + headerComponents := components["headers"].(map[string]any) + for _, name := range []string{"deprecation", "sunset", "link", "xapiwarn"} { + if _, ok := headerComponents[name]; !ok { + t.Fatalf("expected reusable header component %q", name) + } + } + + deprecationHeader := headers["Deprecation"].(map[string]any) + if got := deprecationHeader["$ref"]; got != "#/components/headers/deprecation" { + t.Fatalf("expected Deprecation header to reference shared component, got %v", got) } } -func TestSpecBuilder_Good_EnvelopeWrapping(t *testing.T) { +func TestSpecBuilder_Good_BlankTagsAreIgnored(t *testing.T) { sb := &api.SpecBuilder{ Title: "Test", Version: "1.0.0", } group := &specStubGroup{ - name: "data", - basePath: "/data", + name: " ", + basePath: "/api/blank", descs: []api.RouteDescription{ { Method: "GET", - Path: "/fetch", - Summary: "Fetch data", - Tags: []string{"data"}, + Path: "/status", + Summary: "Check status", + Tags: []string{"", " ", "data", "data"}, Response: map[string]any{ "type": "object", - "properties": map[string]any{ - "value": map[string]any{"type": "string"}, - }, }, }, }, @@ -198,59 +2714,124 @@ func TestSpecBuilder_Good_EnvelopeWrapping(t *testing.T) { t.Fatalf("invalid JSON: %v", err) } - paths := spec["paths"].(map[string]any) - fetchPath := paths["/data/fetch"].(map[string]any) - getOp := fetchPath["get"].(map[string]any) - responses := getOp["responses"].(map[string]any) - resp200 := responses["200"].(map[string]any) - content := resp200["content"].(map[string]any) - appJSON := content["application/json"].(map[string]any) - schema := appJSON["schema"].(map[string]any) + tags := spec["tags"].([]any) + var foundData bool + for _, raw := range tags { + tag := raw.(map[string]any) + name, _ := tag["name"].(string) + if name == "" { + t.Fatal("expected blank tag names to be ignored") + } + if name == "data" { + foundData = true + } + } + if !foundData { + t.Fatal("expected data tag to be retained") + } - // Verify envelope structure. - if schema["type"] != "object" { - t.Fatalf("expected schema type=object, got %v", schema["type"]) + op := spec["paths"].(map[string]any)["/api/blank/status"].(map[string]any)["get"].(map[string]any) + opTags, ok := op["tags"].([]any) + if !ok { + t.Fatalf("expected tags array, got %T", op["tags"]) + } + if len(opTags) != 1 || opTags[0] != "data" { + t.Fatalf("expected operation tags to be cleaned to [data], got %v", opTags) } +} - properties := schema["properties"].(map[string]any) +func TestSpecBuilder_Good_BlankRouteTagsFallBackToGroupName(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + } - // Verify success field. - success := properties["success"].(map[string]any) - if success["type"] != "boolean" { - t.Fatalf("expected success.type=boolean, got %v", success["type"]) + group := &specStubGroup{ + name: "fallback", + basePath: "/api/fallback", + descs: []api.RouteDescription{ + { + Method: "GET", + Path: "/status", + Summary: "Check status", + Tags: []string{"", " "}, + Response: map[string]any{ + "type": "object", + }, + }, + }, } - // Verify data field contains the original response schema. - dataField := properties["data"].(map[string]any) - if dataField["type"] != "object" { - t.Fatalf("expected data.type=object, got %v", dataField["type"]) + data, err := sb.Build([]api.RouteGroup{group}) + if err != nil { + t.Fatalf("unexpected error: %v", err) } - dataProps := dataField["properties"].(map[string]any) - if dataProps["value"] == nil { - t.Fatal("expected data.properties.value to exist") + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) } - // Verify required contains "success". - required := schema["required"].([]any) - foundSuccess := false - for _, r := range required { - if r == "success" { - foundSuccess = true - break - } + op := spec["paths"].(map[string]any)["/api/fallback/status"].(map[string]any)["get"].(map[string]any) + tags, ok := op["tags"].([]any) + if !ok { + t.Fatalf("expected tags array, got %T", op["tags"]) } - if !foundSuccess { - t.Fatal("expected 'success' in required array") + if len(tags) != 1 || tags[0] != "fallback" { + t.Fatalf("expected blank route tags to fall back to group name, got %v", tags) } } -func TestSpecBuilder_Good_NonDescribableGroup(t *testing.T) { +func TestSpecBuilder_Good_HiddenRoutesAreOmitted(t *testing.T) { sb := &api.SpecBuilder{ Title: "Test", Version: "1.0.0", } - data, err := sb.Build([]api.RouteGroup{plainStubGroup{}}) + visible := &specStubGroup{ + name: "visible", + basePath: "/api", + descs: []api.RouteDescription{ + { + Method: "GET", + Path: "/public", + Summary: "Public endpoint", + Tags: []string{"public"}, + Response: map[string]any{ + "type": "object", + }, + }, + { + Method: "GET", + Path: "/internal", + Summary: "Internal endpoint", + Tags: []string{"internal"}, + Hidden: true, + Response: map[string]any{ + "type": "object", + }, + }, + }, + } + + hidden := &specStubGroup{ + name: "hidden-group", + basePath: "/api/internal", + hidden: true, + descs: []api.RouteDescription{ + { + Method: "GET", + Path: "/status", + Summary: "Hidden group endpoint", + Tags: []string{"hidden"}, + Response: map[string]any{ + "type": "object", + }, + }, + }, + } + + data, err := sb.Build([]api.RouteGroup{visible, hidden}) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -260,27 +2841,54 @@ func TestSpecBuilder_Good_NonDescribableGroup(t *testing.T) { t.Fatalf("invalid JSON: %v", err) } - // Verify plainStubGroup appears in tags. + paths := spec["paths"].(map[string]any) + if _, ok := paths["/api/public"]; !ok { + t.Fatal("expected visible route to remain in the spec") + } + if _, ok := paths["/api/internal"]; ok { + t.Fatal("did not expect hidden route to appear in the spec") + } + if _, ok := paths["/api/internal/status"]; ok { + t.Fatal("did not expect hidden group routes to appear in the spec") + } + tags := spec["tags"].([]any) - foundPlain := false - for _, tag := range tags { - tm := tag.(map[string]any) - if tm["name"] == "plain" { - foundPlain = true - break + foundPublic := false + foundInternal := false + foundHidden := false + foundVisibleGroup := false + foundHiddenGroup := false + for _, raw := range tags { + tag := raw.(map[string]any) + name, _ := tag["name"].(string) + switch name { + case "public": + foundPublic = true + case "internal": + foundInternal = true + case "hidden": + foundHidden = true + case "visible": + foundVisibleGroup = true + case "hidden-group": + foundHiddenGroup = true } } - if !foundPlain { - t.Fatal("expected 'plain' tag in spec for non-describable group") - } - // Verify only /health exists in paths (plain group adds no paths). - paths := spec["paths"].(map[string]any) - if len(paths) != 1 { - t.Fatalf("expected 1 path (/health only), got %d", len(paths)) + if !foundPublic { + t.Fatal("expected public tag to remain in the spec") } - if _, ok := paths["/health"]; !ok { - t.Fatal("expected /health path in spec") + if !foundVisibleGroup { + t.Fatal("expected visible group tag to remain in the spec") + } + if foundInternal { + t.Fatal("did not expect hidden route tag to appear in the spec") + } + if foundHidden { + t.Fatal("did not expect hidden group route tag to appear in the spec") + } + if foundHiddenGroup { + t.Fatal("did not expect hidden group tag to appear in the spec") } } @@ -336,6 +2944,25 @@ func TestSpecBuilder_Good_ToolBridgeIntegration(t *testing.T) { t.Fatalf("invalid JSON: %v", err) } + tags, ok := spec["tags"].([]any) + if !ok { + t.Fatalf("expected tags array, got %T", spec["tags"]) + } + expectedTags := map[string]bool{ + "system": true, + "tools": true, + "files": true, + "metrics": true, + } + for _, raw := range tags { + tag := raw.(map[string]any) + name, _ := tag["name"].(string) + delete(expectedTags, name) + } + if len(expectedTags) != 0 { + t.Fatalf("expected declared tags to include system, tools, files, and metrics, missing %v", expectedTags) + } + paths := spec["paths"].(map[string]any) // Verify POST /tools/file_read exists. @@ -350,6 +2977,9 @@ func TestSpecBuilder_Good_ToolBridgeIntegration(t *testing.T) { if postOp.(map[string]any)["summary"] != "Read a file from disk" { t.Fatalf("expected summary='Read a file from disk', got %v", postOp.(map[string]any)["summary"]) } + if postOp.(map[string]any)["operationId"] != "post_tools_file_read" { + t.Fatalf("expected operationId='post_tools_file_read', got %v", postOp.(map[string]any)["operationId"]) + } // Verify POST /tools/metrics_query exists. metricsPath, ok := paths["/tools/metrics_query"] @@ -363,6 +2993,9 @@ func TestSpecBuilder_Good_ToolBridgeIntegration(t *testing.T) { if metricsOp.(map[string]any)["summary"] != "Query metrics data" { t.Fatalf("expected summary='Query metrics data', got %v", metricsOp.(map[string]any)["summary"]) } + if metricsOp.(map[string]any)["operationId"] != "post_tools_metrics_query" { + t.Fatalf("expected operationId='post_tools_metrics_query', got %v", metricsOp.(map[string]any)["operationId"]) + } // Verify request body is present on both (both are POST with InputSchema). if postOp.(map[string]any)["requestBody"] == nil { @@ -401,3 +3034,117 @@ func TestSpecBuilder_Bad_InfoFields(t *testing.T) { t.Fatalf("expected version=1.0.0, got %v", info["version"]) } } + +func TestSpecBuilder_Good_Servers(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + Servers: []string{ + " https://api.example.com ", + "/", + "", + "https://api.example.com", + }, + } + + data, err := sb.Build(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + servers, ok := spec["servers"].([]any) + if !ok { + t.Fatalf("expected servers array, got %T", spec["servers"]) + } + if len(servers) != 2 { + t.Fatalf("expected 2 normalised servers, got %d", len(servers)) + } + + first := servers[0].(map[string]any) + if first["url"] != "https://api.example.com" { + t.Fatalf("expected first server url=%q, got %v", "https://api.example.com", first["url"]) + } + second := servers[1].(map[string]any) + if second["url"] != "/" { + t.Fatalf("expected second server url=%q, got %v", "/", second["url"]) + } +} + +func TestSpecBuilder_Good_ServersCollapseTrailingSlashes(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + Servers: []string{ + "https://api.example.com/", + "https://api.example.com", + "/api/", + "/api", + }, + } + + data, err := sb.Build(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + servers, ok := spec["servers"].([]any) + if !ok { + t.Fatalf("expected servers array, got %T", spec["servers"]) + } + if len(servers) != 2 { + t.Fatalf("expected 2 collapsed servers, got %d", len(servers)) + } + + first := servers[0].(map[string]any) + if first["url"] != "https://api.example.com" { + t.Fatalf("expected first server url=%q, got %v", "https://api.example.com", first["url"]) + } + second := servers[1].(map[string]any) + if second["url"] != "/api" { + t.Fatalf("expected second server url=%q, got %v", "/api", second["url"]) + } +} + +func TestSpecBuilder_Good_RuntimeDebugEndpointsDocumentRateLimitHeaders(t *testing.T) { + sb := &api.SpecBuilder{ + Title: "Test", + Version: "1.0.0", + PprofEnabled: true, + ExpvarEnabled: true, + } + + data, err := sb.Build(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + paths := spec["paths"].(map[string]any) + for _, path := range []string{"/debug/pprof", "/debug/vars"} { + item := paths[path].(map[string]any) + op := item["get"].(map[string]any) + for _, code := range []string{"200", "401", "403"} { + resp := op["responses"].(map[string]any)[code].(map[string]any) + headers := resp["headers"].(map[string]any) + for _, name := range []string{"X-Request-ID", "X-RateLimit-Limit", "X-RateLimit-Remaining", "X-RateLimit-Reset"} { + if _, ok := headers[name]; !ok { + t.Fatalf("expected %s header on %s %s response", name, path, code) + } + } + } + } +} diff --git a/options.go b/options.go index bdf3f66..6f10798 100644 --- a/options.go +++ b/options.go @@ -7,6 +7,7 @@ import ( "log/slog" "net/http" "slices" + "strings" "time" "github.com/99designs/gqlgen/graphql" @@ -26,9 +27,17 @@ import ( ) // Option configures an Engine during construction. +// +// Example: +// +// engine, _ := api.New(api.WithAddr(":8080")) type Option func(*Engine) // WithAddr sets the listen address for the server. +// +// Example: +// +// api.New(api.WithAddr(":8443")) func WithAddr(addr string) Option { return func(e *Engine) { e.addr = addr @@ -36,26 +45,57 @@ func WithAddr(addr string) Option { } // WithBearerAuth adds bearer token authentication middleware. -// Requests to /health and paths starting with /swagger are exempt. +// Requests to /health and the Swagger UI path are exempt. +// +// Example: +// +// api.New(api.WithBearerAuth("secret")) func WithBearerAuth(token string) Option { return func(e *Engine) { - skip := []string{"/health", "/swagger"} - e.middlewares = append(e.middlewares, bearerAuthMiddleware(token, skip)) + e.middlewares = append(e.middlewares, bearerAuthMiddleware(token, func() []string { + skip := []string{"/health"} + if swaggerPath := resolveSwaggerPath(e.swaggerPath); swaggerPath != "" { + skip = append(skip, swaggerPath) + } + return skip + })) } } // WithRequestID adds middleware that assigns an X-Request-ID to every response. // Client-provided IDs are preserved; otherwise a random hex ID is generated. +// +// Example: +// +// api.New(api.WithRequestID()) func WithRequestID() Option { return func(e *Engine) { e.middlewares = append(e.middlewares, requestIDMiddleware()) } } +// WithResponseMeta attaches request metadata to JSON envelope responses. +// It preserves any existing pagination metadata and merges in request_id +// and duration when available from the request context. Combine it with +// WithRequestID() to populate both fields automatically. +// +// Example: +// +// api.New(api.WithRequestID(), api.WithResponseMeta()) +func WithResponseMeta() Option { + return func(e *Engine) { + e.middlewares = append(e.middlewares, responseMetaMiddleware()) + } +} + // WithCORS configures Cross-Origin Resource Sharing via gin-contrib/cors. // Pass "*" to allow all origins, or supply specific origin URLs. // Standard methods (GET, POST, PUT, PATCH, DELETE, OPTIONS) and common // headers (Authorization, Content-Type, X-Request-ID) are permitted. +// +// Example: +// +// api.New(api.WithCORS("*")) func WithCORS(allowOrigins ...string) Option { return func(e *Engine) { cfg := cors.Config{ @@ -76,6 +116,10 @@ func WithCORS(allowOrigins ...string) Option { } // WithMiddleware appends arbitrary Gin middleware to the engine. +// +// Example: +// +// api.New(api.WithMiddleware(loggingMiddleware)) func WithMiddleware(mw ...gin.HandlerFunc) Option { return func(e *Engine) { e.middlewares = append(e.middlewares, mw...) @@ -85,6 +129,10 @@ func WithMiddleware(mw ...gin.HandlerFunc) Option { // WithStatic serves static files from the given root directory at urlPrefix. // Directory listing is disabled; only individual files are served. // Internally this uses gin-contrib/static as Gin middleware. +// +// Example: +// +// api.New(api.WithStatic("/assets", "./public")) func WithStatic(urlPrefix, root string) Option { return func(e *Engine) { e.middlewares = append(e.middlewares, static.Serve(urlPrefix, static.LocalFile(root, false))) @@ -92,33 +140,215 @@ func WithStatic(urlPrefix, root string) Option { } // WithWSHandler registers a WebSocket handler at GET /ws. +// Use WithWSPath to customise the route before mounting the handler. // Typically this wraps a go-ws Hub.Handler(). +// +// Example: +// +// api.New(api.WithWSHandler(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))) func WithWSHandler(h http.Handler) Option { return func(e *Engine) { e.wsHandler = h } } +// WithWSPath sets a custom URL path for the WebSocket endpoint. +// The default path is "/ws". +// +// Example: +// +// api.New(api.WithWSPath("/socket")) +func WithWSPath(path string) Option { + return func(e *Engine) { + e.wsPath = normaliseWSPath(path) + } +} + // WithAuthentik adds Authentik forward-auth middleware that extracts user // identity from X-authentik-* headers set by a trusted reverse proxy. // The middleware is permissive: unauthenticated requests are allowed through. +// +// Example: +// +// api.New(api.WithAuthentik(api.AuthentikConfig{TrustedProxy: true})) func WithAuthentik(cfg AuthentikConfig) Option { return func(e *Engine) { - e.middlewares = append(e.middlewares, authentikMiddleware(cfg)) + snapshot := cloneAuthentikConfig(cfg) + e.authentikConfig = snapshot + e.middlewares = append(e.middlewares, authentikMiddleware(snapshot, func() []string { + return []string{resolveSwaggerPath(e.swaggerPath)} + })) + } +} + +// WithSunset adds deprecation headers to every response. +// The middleware appends Deprecation, optional Sunset, optional Link, and +// X-API-Warn headers without clobbering any existing header values. Use it to +// deprecate an entire route group or API version. +// +// Example: +// +// api.New(api.WithSunset("2026-12-31", "https://api.example.com/v2")) +func WithSunset(sunsetDate, replacement string) Option { + return func(e *Engine) { + e.middlewares = append(e.middlewares, ApiSunset(sunsetDate, replacement)) } } -// WithSwagger enables the Swagger UI at /swagger/. +// WithSwagger enables the Swagger UI at /swagger/ by default. // The title, description, and version populate the OpenAPI info block. +// Use WithSwaggerSummary() to set the optional info.summary field. +// +// Example: +// +// api.New(api.WithSwagger("Service", "Public API", "1.0.0")) func WithSwagger(title, description, version string) Option { return func(e *Engine) { - e.swaggerTitle = title - e.swaggerDesc = description - e.swaggerVersion = version + e.swaggerTitle = strings.TrimSpace(title) + e.swaggerDesc = strings.TrimSpace(description) + e.swaggerVersion = strings.TrimSpace(version) e.swaggerEnabled = true } } +// WithSwaggerSummary adds the OpenAPI info.summary field to generated specs. +// +// Example: +// +// api.WithSwaggerSummary("Service overview") +func WithSwaggerSummary(summary string) Option { + return func(e *Engine) { + if summary = strings.TrimSpace(summary); summary != "" { + e.swaggerSummary = summary + } + } +} + +// WithSwaggerPath sets a custom URL path for the Swagger UI. +// The default path is "/swagger". +// +// Example: +// +// api.New(api.WithSwaggerPath("/docs")) +func WithSwaggerPath(path string) Option { + return func(e *Engine) { + e.swaggerPath = normaliseSwaggerPath(path) + } +} + +// WithSwaggerTermsOfService adds the terms of service URL to the generated Swagger spec. +// Empty strings are ignored. +// +// Example: +// +// api.WithSwaggerTermsOfService("https://example.com/terms") +func WithSwaggerTermsOfService(url string) Option { + return func(e *Engine) { + if url = strings.TrimSpace(url); url != "" { + e.swaggerTermsOfService = url + } + } +} + +// WithSwaggerContact adds contact metadata to the generated Swagger spec. +// Empty fields are ignored. Multiple calls replace the previous contact data. +// +// Example: +// +// api.WithSwaggerContact("API Support", "https://example.com/support", "support@example.com") +func WithSwaggerContact(name, url, email string) Option { + return func(e *Engine) { + if name = strings.TrimSpace(name); name != "" { + e.swaggerContactName = name + } + if url = strings.TrimSpace(url); url != "" { + e.swaggerContactURL = url + } + if email = strings.TrimSpace(email); email != "" { + e.swaggerContactEmail = email + } + } +} + +// WithSwaggerServers adds OpenAPI server metadata to the generated Swagger spec. +// Empty strings are ignored. Multiple calls append and normalise the combined +// server list so callers can compose metadata across options. +// +// Example: +// +// api.WithSwaggerServers("https://api.example.com", "https://docs.example.com") +func WithSwaggerServers(servers ...string) Option { + return func(e *Engine) { + e.swaggerServers = normaliseServers(append(e.swaggerServers, servers...)) + } +} + +// WithSwaggerLicense adds licence metadata to the generated Swagger spec. +// Pass both a name and URL to populate the OpenAPI info block consistently. +// +// Example: +// +// api.WithSwaggerLicense("EUPL-1.2", "https://eupl.eu/1.2/en/") +func WithSwaggerLicense(name, url string) Option { + return func(e *Engine) { + if name = strings.TrimSpace(name); name != "" { + e.swaggerLicenseName = name + } + if url = strings.TrimSpace(url); url != "" { + e.swaggerLicenseURL = url + } + } +} + +// WithSwaggerSecuritySchemes merges custom OpenAPI security schemes into the +// generated Swagger spec. Existing schemes are preserved unless the new map +// defines the same key, in which case the later definition wins. +// +// Example: +// +// api.WithSwaggerSecuritySchemes(map[string]any{ +// "apiKeyAuth": map[string]any{ +// "type": "apiKey", +// "in": "header", +// "name": "X-API-Key", +// }, +// }) +func WithSwaggerSecuritySchemes(schemes map[string]any) Option { + return func(e *Engine) { + if len(schemes) == 0 { + return + } + if e.swaggerSecuritySchemes == nil { + e.swaggerSecuritySchemes = make(map[string]any, len(schemes)) + } + for name, scheme := range schemes { + name = strings.TrimSpace(name) + if name == "" || scheme == nil { + continue + } + e.swaggerSecuritySchemes[name] = cloneOpenAPIValue(scheme) + } + } +} + +// WithSwaggerExternalDocs adds top-level external documentation metadata to +// the generated Swagger spec. +// Empty URLs are ignored; the description is optional. +// +// Example: +// +// api.WithSwaggerExternalDocs("Developer guide", "https://example.com/docs") +func WithSwaggerExternalDocs(description, url string) Option { + return func(e *Engine) { + if description = strings.TrimSpace(description); description != "" { + e.swaggerExternalDocsDescription = description + } + if url = strings.TrimSpace(url); url != "" { + e.swaggerExternalDocsURL = url + } + } +} + // WithPprof enables Go runtime profiling endpoints at /debug/pprof/. // The standard pprof handlers (index, cmdline, profile, symbol, trace, // allocs, block, goroutine, heap, mutex, threadcreate) are registered @@ -126,6 +356,10 @@ func WithSwagger(title, description, version string) Option { // // WARNING: pprof exposes sensitive runtime data and should only be // enabled in development or behind authentication in production. +// +// Example: +// +// api.New(api.WithPprof()) func WithPprof() Option { return func(e *Engine) { e.pprofEnabled = true @@ -140,6 +374,10 @@ func WithPprof() Option { // WARNING: expvar exposes runtime internals (memory allocation, // goroutine counts, command-line arguments) and should only be // enabled in development or behind authentication in production. +// +// Example: +// +// api.New(api.WithExpvar()) func WithExpvar() Option { return func(e *Engine) { e.expvarEnabled = true @@ -151,6 +389,10 @@ func WithExpvar() Option { // X-Content-Type-Options nosniff, and Referrer-Policy strict-origin-when-cross-origin. // SSL redirect is not enabled so the middleware works behind a reverse proxy // that terminates TLS. +// +// Example: +// +// api.New(api.WithSecure()) func WithSecure() Option { return func(e *Engine) { e.middlewares = append(e.middlewares, secure.New(secure.Config{ @@ -167,6 +409,10 @@ func WithSecure() Option { // WithGzip adds gzip response compression middleware via gin-contrib/gzip. // An optional compression level may be supplied (e.g. gzip.BestSpeed, // gzip.BestCompression). If omitted, gzip.DefaultCompression is used. +// +// Example: +// +// api.New(api.WithGzip()) func WithGzip(level ...int) Option { return func(e *Engine) { l := gzip.DefaultCompression @@ -180,6 +426,10 @@ func WithGzip(level ...int) Option { // WithBrotli adds Brotli response compression middleware using andybalholm/brotli. // An optional compression level may be supplied (e.g. BrotliBestSpeed, // BrotliBestCompression). If omitted, BrotliDefaultCompression is used. +// +// Example: +// +// api.New(api.WithBrotli()) func WithBrotli(level ...int) Option { return func(e *Engine) { l := BrotliDefaultCompression @@ -193,6 +443,10 @@ func WithBrotli(level ...int) Option { // WithSlog adds structured request logging middleware via gin-contrib/slog. // Each request is logged with method, path, status code, latency, and client IP. // If logger is nil, slog.Default() is used. +// +// Example: +// +// api.New(api.WithSlog(nil)) func WithSlog(logger *slog.Logger) Option { return func(e *Engine) { if logger == nil { @@ -214,8 +468,15 @@ func WithSlog(logger *slog.Logger) Option { // // A zero or negative duration effectively disables the timeout (the handler // runs without a deadline) — this is safe and will not panic. +// +// Example: +// +// api.New(api.WithTimeout(5 * time.Second)) func WithTimeout(d time.Duration) Option { return func(e *Engine) { + if d <= 0 { + return + } e.middlewares = append(e.middlewares, timeout.New( timeout.WithTimeout(d), timeout.WithResponse(timeoutResponse), @@ -232,17 +493,77 @@ func timeoutResponse(c *gin.Context) { // Successful (2xx) GET responses are cached for the given TTL and served // with an X-Cache: HIT header on subsequent requests. Non-GET methods // and error responses pass through uncached. -func WithCache(ttl time.Duration) Option { +// +// Optional integer limits enable LRU eviction: +// - maxEntries limits the number of cached responses +// - maxBytes limits the approximate total cached payload size +// +// Pass a non-positive value to either limit to leave that dimension +// unbounded for backward compatibility. A non-positive TTL disables the +// middleware entirely. +// +// Example: +// +// engine, _ := api.New(api.WithCache(5*time.Minute, 100, 10<<20)) +func WithCache(ttl time.Duration, maxEntries ...int) Option { + entryLimit := 0 + byteLimit := 0 + if len(maxEntries) > 0 { + entryLimit = maxEntries[0] + } + if len(maxEntries) > 1 { + byteLimit = maxEntries[1] + } + return WithCacheLimits(ttl, entryLimit, byteLimit) +} + +// WithCacheLimits adds in-memory response caching middleware for GET requests +// with explicit entry and payload-size bounds. +// +// This is the clearer form of WithCache when call sites want to make the +// eviction policy self-documenting. +// +// Example: +// +// engine, _ := api.New(api.WithCacheLimits(5*time.Minute, 100, 10<<20)) +func WithCacheLimits(ttl time.Duration, maxEntries, maxBytes int) Option { return func(e *Engine) { - store := newCacheStore() + if ttl <= 0 { + return + } + e.cacheTTL = ttl + e.cacheMaxEntries = maxEntries + e.cacheMaxBytes = maxBytes + store := newCacheStore(maxEntries, maxBytes) e.middlewares = append(e.middlewares, cacheMiddleware(store, ttl)) } } +// WithRateLimit adds token-bucket rate limiting middleware. +// Requests are bucketed by API key or bearer token when present, and +// otherwise by client IP. Passing requests are annotated with +// X-RateLimit-Limit, X-RateLimit-Remaining, and X-RateLimit-Reset headers. +// Requests exceeding the configured limit are rejected with 429 Too Many +// Requests, Retry-After, and the standard Fail() error envelope. +// A zero or negative limit disables rate limiting. +// +// Example: +// +// engine, _ := api.New(api.WithRateLimit(100)) +func WithRateLimit(limit int) Option { + return func(e *Engine) { + e.middlewares = append(e.middlewares, rateLimitMiddleware(limit)) + } +} + // WithSessions adds server-side session management middleware via // gin-contrib/sessions using a cookie-based store. The name parameter // sets the session cookie name (e.g. "session") and secret is the key // used for cookie signing and encryption. +// +// Example: +// +// api.New(api.WithSessions("session", []byte("secret"))) func WithSessions(name string, secret []byte) Option { return func(e *Engine) { store := cookie.NewStore(secret) @@ -255,6 +576,10 @@ func WithSessions(name string, secret []byte) Option { // holding the desired model and policy rules. The middleware extracts the // subject from HTTP Basic Authentication, evaluates it against the request // method and path, and returns 403 Forbidden when the policy denies access. +// +// Example: +// +// api.New(api.WithAuthz(enforcer)) func WithAuthz(enforcer *casbin.Enforcer) Option { return func(e *Engine) { e.middlewares = append(e.middlewares, authz.NewAuthorizer(enforcer)) @@ -274,6 +599,10 @@ func WithAuthz(enforcer *casbin.Enforcer) Option { // // Requests with a missing, malformed, or invalid signature are rejected with // 401 Unauthorised or 400 Bad Request. +// +// Example: +// +// api.New(api.WithHTTPSign(secrets)) func WithHTTPSign(secrets httpsign.Secrets, opts ...httpsign.Option) Option { return func(e *Engine) { auth := httpsign.NewAuthenticator(secrets, opts...) @@ -281,16 +610,34 @@ func WithHTTPSign(secrets httpsign.Secrets, opts ...httpsign.Option) Option { } } -// WithSSE registers a Server-Sent Events broker at GET /events. -// Clients connect to the endpoint and receive a streaming text/event-stream -// response. The broker manages client connections and broadcasts events +// WithSSE registers a Server-Sent Events broker at the configured path. +// By default the endpoint is mounted at GET /events; use WithSSEPath to +// customise the route. Clients receive a streaming text/event-stream +// response and the broker manages client connections and broadcasts events // published via its Publish method. +// +// Example: +// +// broker := api.NewSSEBroker() +// engine, _ := api.New(api.WithSSE(broker)) func WithSSE(broker *SSEBroker) Option { return func(e *Engine) { e.sseBroker = broker } } +// WithSSEPath sets a custom URL path for the SSE endpoint. +// The default path is "/events". +// +// Example: +// +// api.New(api.WithSSEPath("/stream")) +func WithSSEPath(path string) Option { + return func(e *Engine) { + e.ssePath = normaliseSSEPath(path) + } +} + // WithLocation adds reverse proxy header detection middleware via // gin-contrib/location. It inspects X-Forwarded-Proto and X-Forwarded-Host // headers to determine the original scheme and host when the server runs @@ -298,6 +645,10 @@ func WithSSE(broker *SSEBroker) Option { // // After this middleware runs, handlers can call location.Get(c) to retrieve // a *url.URL with the detected scheme, host, and base path. +// +// Example: +// +// engine, _ := api.New(api.WithLocation()) func WithLocation() Option { return func(e *Engine) { e.middlewares = append(e.middlewares, location.Default()) @@ -311,6 +662,10 @@ func WithLocation() Option { // api.New( // api.WithGraphQL(schema, api.WithPlayground(), api.WithGraphQLPath("/gql")), // ) +// +// Example: +// +// engine, _ := api.New(api.WithGraphQL(schema)) func WithGraphQL(schema graphql.ExecutableSchema, opts ...GraphQLOption) Option { return func(e *Engine) { cfg := &graphqlConfig{ diff --git a/pkg/provider/provider.go b/pkg/provider/provider.go index 1399c79..0358ef6 100644 --- a/pkg/provider/provider.go +++ b/pkg/provider/provider.go @@ -1,4 +1,4 @@ -// SPDX-Licence-Identifier: EUPL-1.2 +// SPDX-License-Identifier: EUPL-1.2 // Package provider defines the Service Provider Framework interfaces. // diff --git a/pkg/provider/proxy.go b/pkg/provider/proxy.go index 9eb2b4a..e2ef86b 100644 --- a/pkg/provider/proxy.go +++ b/pkg/provider/proxy.go @@ -1,13 +1,15 @@ -// SPDX-Licence-Identifier: EUPL-1.2 +// SPDX-License-Identifier: EUPL-1.2 package provider import ( + "fmt" "net/http" "net/http/httputil" "net/url" "strings" + coreapi "dappco.re/go/core/api" "github.com/gin-gonic/gin" ) @@ -39,14 +41,30 @@ type ProxyConfig struct { type ProxyProvider struct { config ProxyConfig proxy *httputil.ReverseProxy + err error } // NewProxy creates a ProxyProvider from the given configuration. -// The upstream URL must be valid or NewProxy will panic. +// Invalid upstream URLs do not panic; the provider retains the +// configuration error and responds with a standard 500 envelope when +// mounted. This keeps provider construction safe for callers. func NewProxy(cfg ProxyConfig) *ProxyProvider { target, err := url.Parse(cfg.Upstream) if err != nil { - panic("provider.NewProxy: invalid upstream URL: " + err.Error()) + return &ProxyProvider{ + config: cfg, + err: err, + } + } + + // url.Parse accepts inputs like "127.0.0.1:9901" without error — they + // parse without a scheme or host, which causes httputil.ReverseProxy to + // fail silently at runtime. Require both to be present. + if target.Scheme == "" || target.Host == "" { + return &ProxyProvider{ + config: cfg, + err: fmt.Errorf("upstream %q must include a scheme and host (e.g. http://127.0.0.1:9901)", cfg.Upstream), + } } proxy := httputil.NewSingleHostReverseProxy(target) @@ -59,11 +77,10 @@ func NewProxy(cfg ProxyConfig) *ProxyProvider { proxy.Director = func(req *http.Request) { defaultDirector(req) // Strip the base path prefix from the request path. - req.URL.Path = strings.TrimPrefix(req.URL.Path, basePath) - if req.URL.Path == "" { - req.URL.Path = "/" + req.URL.Path = stripBasePath(req.URL.Path, basePath) + if req.URL.RawPath != "" { + req.URL.RawPath = stripBasePath(req.URL.RawPath, basePath) } - req.URL.RawPath = strings.TrimPrefix(req.URL.RawPath, basePath) } return &ProxyProvider{ @@ -72,6 +89,43 @@ func NewProxy(cfg ProxyConfig) *ProxyProvider { } } +// Err reports any configuration error detected while constructing the proxy. +// A nil error means the proxy is ready to mount and serve requests. +func (p *ProxyProvider) Err() error { + if p == nil { + return nil + } + return p.err +} + +// stripBasePath removes an exact base path prefix from a request path. +// It only strips when the path matches the base path itself or lives under +// the base path boundary, so "/api" will not accidentally trim "/api-v2". +func stripBasePath(path, basePath string) string { + basePath = strings.TrimSuffix(strings.TrimSpace(basePath), "/") + if basePath == "" || basePath == "/" { + if path == "" { + return "/" + } + return path + } + + if path == basePath { + return "/" + } + + prefix := basePath + "/" + if strings.HasPrefix(path, prefix) { + trimmed := strings.TrimPrefix(path, basePath) + if trimmed == "" { + return "/" + } + return trimmed + } + + return path +} + // Name returns the provider identity. func (p *ProxyProvider) Name() string { return p.config.Name @@ -85,6 +139,19 @@ func (p *ProxyProvider) BasePath() string { // RegisterRoutes mounts a catch-all reverse proxy handler on the router group. func (p *ProxyProvider) RegisterRoutes(rg *gin.RouterGroup) { rg.Any("/*path", func(c *gin.Context) { + if p == nil || p.err != nil || p.proxy == nil { + details := map[string]any{} + if p != nil && p.err != nil { + details["error"] = p.err.Error() + } + c.JSON(http.StatusInternalServerError, coreapi.FailWithDetails( + "invalid_provider_configuration", + "Provider is misconfigured", + details, + )) + return + } + // Use the underlying http.ResponseWriter directly. Gin's // responseWriter wrapper does not implement http.CloseNotifier, // which httputil.ReverseProxy requires for cancellation signalling. diff --git a/pkg/provider/proxy_internal_test.go b/pkg/provider/proxy_internal_test.go new file mode 100644 index 0000000..133d678 --- /dev/null +++ b/pkg/provider/proxy_internal_test.go @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package provider + +import "testing" + +func TestStripBasePath_Good_ExactBoundary(t *testing.T) { + got := stripBasePath("/api/v1/cool-widget/items", "/api/v1/cool-widget") + if got != "/items" { + t.Fatalf("expected stripped path %q, got %q", "/items", got) + } +} + +func TestStripBasePath_Good_RootPath(t *testing.T) { + got := stripBasePath("/api/v1/cool-widget", "/api/v1/cool-widget") + if got != "/" { + t.Fatalf("expected stripped root path %q, got %q", "/", got) + } +} + +func TestStripBasePath_Good_DoesNotTrimPartialPrefix(t *testing.T) { + got := stripBasePath("/api/v1/cool-widget-2/items", "/api/v1/cool-widget") + if got != "/api/v1/cool-widget-2/items" { + t.Fatalf("expected partial prefix to remain unchanged, got %q", got) + } +} diff --git a/pkg/provider/proxy_test.go b/pkg/provider/proxy_test.go index c1bc536..5f15253 100644 --- a/pkg/provider/proxy_test.go +++ b/pkg/provider/proxy_test.go @@ -183,11 +183,32 @@ func TestProxyProvider_Renderable_Good(t *testing.T) { } func TestProxyProvider_Ugly_InvalidUpstream(t *testing.T) { - assert.Panics(t, func() { - provider.NewProxy(provider.ProxyConfig{ - Name: "bad", - BasePath: "/api/v1/bad", - Upstream: "://not-a-url", - }) + p := provider.NewProxy(provider.ProxyConfig{ + Name: "bad", + BasePath: "/api/v1/bad", + Upstream: "://not-a-url", }) + + require.NotNil(t, p) + assert.Error(t, p.Err()) + + engine, err := api.New() + require.NoError(t, err) + engine.Register(p) + + handler := engine.Handler() + + req := httptest.NewRequest(http.MethodGet, "/api/v1/bad/items", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusInternalServerError, w.Code) + + var body map[string]any + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + + assert.Equal(t, false, body["success"]) + errObj, ok := body["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "invalid_provider_configuration", errObj["code"]) } diff --git a/pkg/provider/registry.go b/pkg/provider/registry.go index 89ce596..fdd40c2 100644 --- a/pkg/provider/registry.go +++ b/pkg/provider/registry.go @@ -1,4 +1,4 @@ -// SPDX-Licence-Identifier: EUPL-1.2 +// SPDX-License-Identifier: EUPL-1.2 package provider @@ -88,6 +88,24 @@ func (r *Registry) Streamable() []Streamable { return result } +// StreamableIter returns an iterator over all registered providers that +// implement the Streamable interface. +func (r *Registry) StreamableIter() iter.Seq[Streamable] { + r.mu.RLock() + providers := slices.Clone(r.providers) + r.mu.RUnlock() + + return func(yield func(Streamable) bool) { + for _, p := range providers { + if s, ok := p.(Streamable); ok { + if !yield(s) { + return + } + } + } + } +} + // Describable returns all providers that implement the Describable interface. func (r *Registry) Describable() []Describable { r.mu.RLock() @@ -101,6 +119,24 @@ func (r *Registry) Describable() []Describable { return result } +// DescribableIter returns an iterator over all registered providers that +// implement the Describable interface. +func (r *Registry) DescribableIter() iter.Seq[Describable] { + r.mu.RLock() + providers := slices.Clone(r.providers) + r.mu.RUnlock() + + return func(yield func(Describable) bool) { + for _, p := range providers { + if d, ok := p.(Describable); ok { + if !yield(d) { + return + } + } + } + } +} + // Renderable returns all providers that implement the Renderable interface. func (r *Registry) Renderable() []Renderable { r.mu.RLock() @@ -114,12 +150,32 @@ func (r *Registry) Renderable() []Renderable { return result } +// RenderableIter returns an iterator over all registered providers that +// implement the Renderable interface. +func (r *Registry) RenderableIter() iter.Seq[Renderable] { + r.mu.RLock() + providers := slices.Clone(r.providers) + r.mu.RUnlock() + + return func(yield func(Renderable) bool) { + for _, p := range providers { + if rv, ok := p.(Renderable); ok { + if !yield(rv) { + return + } + } + } + } +} + // ProviderInfo is a serialisable summary of a registered provider. type ProviderInfo struct { Name string `json:"name"` BasePath string `json:"basePath"` Channels []string `json:"channels,omitempty"` Element *ElementSpec `json:"element,omitempty"` + SpecFile string `json:"specFile,omitempty"` + Upstream string `json:"upstream,omitempty"` } // Info returns a summary of all registered providers. @@ -140,7 +196,76 @@ func (r *Registry) Info() []ProviderInfo { elem := rv.Element() info.Element = &elem } + if sf, ok := p.(interface{ SpecFile() string }); ok { + info.SpecFile = sf.SpecFile() + } + if up, ok := p.(interface{ Upstream() string }); ok { + info.Upstream = up.Upstream() + } infos = append(infos, info) } return infos } + +// InfoIter returns an iterator over all registered provider summaries. +// The iterator snapshots the current registry contents so callers can range +// over it without holding the registry lock. +func (r *Registry) InfoIter() iter.Seq[ProviderInfo] { + r.mu.RLock() + providers := slices.Clone(r.providers) + r.mu.RUnlock() + + return func(yield func(ProviderInfo) bool) { + for _, p := range providers { + info := ProviderInfo{ + Name: p.Name(), + BasePath: p.BasePath(), + } + if s, ok := p.(Streamable); ok { + info.Channels = s.Channels() + } + if rv, ok := p.(Renderable); ok { + elem := rv.Element() + info.Element = &elem + } + if sf, ok := p.(interface{ SpecFile() string }); ok { + info.SpecFile = sf.SpecFile() + } + if up, ok := p.(interface{ Upstream() string }); ok { + info.Upstream = up.Upstream() + } + if !yield(info) { + return + } + } + } +} + +// SpecFiles returns all non-empty provider OpenAPI spec file paths. +// The result is deduplicated and sorted for stable discovery output. +func (r *Registry) SpecFiles() []string { + r.mu.RLock() + defer r.mu.RUnlock() + + files := make(map[string]struct{}, len(r.providers)) + for _, p := range r.providers { + if sf, ok := p.(interface{ SpecFile() string }); ok { + if path := sf.SpecFile(); path != "" { + files[path] = struct{}{} + } + } + } + + out := make([]string, 0, len(files)) + for path := range files { + out = append(out, path) + } + + slices.Sort(out) + return out +} + +// SpecFilesIter returns an iterator over all non-empty provider OpenAPI spec files. +func (r *Registry) SpecFilesIter() iter.Seq[string] { + return slices.Values(r.SpecFiles()) +} diff --git a/pkg/provider/registry_test.go b/pkg/provider/registry_test.go index 3ff09f8..37d2fb2 100644 --- a/pkg/provider/registry_test.go +++ b/pkg/provider/registry_test.go @@ -38,6 +38,13 @@ func (r *renderableProvider) Element() provider.ElementSpec { return provider.ElementSpec{Tag: "core-stub-panel", Source: "/assets/stub.js"} } +type specFileProvider struct { + stubProvider + specFile string +} + +func (s *specFileProvider) SpecFile() string { return s.specFile } + type fullProvider struct { streamableProvider } @@ -112,9 +119,39 @@ func TestRegistry_Streamable_Good(t *testing.T) { assert.Equal(t, []string{"stub.event"}, s[0].Channels()) } +func TestRegistry_StreamableIter_Good(t *testing.T) { + reg := provider.NewRegistry() + reg.Add(&stubProvider{}) + reg.Add(&streamableProvider{}) + + var streamables []provider.Streamable + for s := range reg.StreamableIter() { + streamables = append(streamables, s) + } + + assert.Len(t, streamables, 1) + assert.Equal(t, []string{"stub.event"}, streamables[0].Channels()) +} + +func TestRegistry_StreamableIter_Good_SnapshotCurrentProviders(t *testing.T) { + reg := provider.NewRegistry() + reg.Add(&streamableProvider{}) + + iter := reg.StreamableIter() + reg.Add(&streamableProvider{}) + + var streamables []provider.Streamable + for s := range iter { + streamables = append(streamables, s) + } + + assert.Len(t, streamables, 1) + assert.Equal(t, []string{"stub.event"}, streamables[0].Channels()) +} + func TestRegistry_Describable_Good(t *testing.T) { reg := provider.NewRegistry() - reg.Add(&stubProvider{}) // not describable + reg.Add(&stubProvider{}) // not describable reg.Add(&describableProvider{}) // describable d := reg.Describable() @@ -122,6 +159,36 @@ func TestRegistry_Describable_Good(t *testing.T) { assert.Len(t, d[0].Describe(), 1) } +func TestRegistry_DescribableIter_Good(t *testing.T) { + reg := provider.NewRegistry() + reg.Add(&stubProvider{}) + reg.Add(&describableProvider{}) + + var describables []provider.Describable + for d := range reg.DescribableIter() { + describables = append(describables, d) + } + + assert.Len(t, describables, 1) + assert.Len(t, describables[0].Describe(), 1) +} + +func TestRegistry_DescribableIter_Good_SnapshotCurrentProviders(t *testing.T) { + reg := provider.NewRegistry() + reg.Add(&describableProvider{}) + + iter := reg.DescribableIter() + reg.Add(&describableProvider{}) + + var describables []provider.Describable + for d := range iter { + describables = append(describables, d) + } + + assert.Len(t, describables, 1) + assert.Len(t, describables[0].Describe(), 1) +} + func TestRegistry_Renderable_Good(t *testing.T) { reg := provider.NewRegistry() reg.Add(&stubProvider{}) // not renderable @@ -132,6 +199,36 @@ func TestRegistry_Renderable_Good(t *testing.T) { assert.Equal(t, "core-stub-panel", r[0].Element().Tag) } +func TestRegistry_RenderableIter_Good(t *testing.T) { + reg := provider.NewRegistry() + reg.Add(&stubProvider{}) + reg.Add(&renderableProvider{}) + + var renderables []provider.Renderable + for r := range reg.RenderableIter() { + renderables = append(renderables, r) + } + + assert.Len(t, renderables, 1) + assert.Equal(t, "core-stub-panel", renderables[0].Element().Tag) +} + +func TestRegistry_RenderableIter_Good_SnapshotCurrentProviders(t *testing.T) { + reg := provider.NewRegistry() + reg.Add(&renderableProvider{}) + + iter := reg.RenderableIter() + reg.Add(&renderableProvider{}) + + var renderables []provider.Renderable + for r := range iter { + renderables = append(renderables, r) + } + + assert.Len(t, renderables, 1) + assert.Equal(t, "core-stub-panel", renderables[0].Element().Tag) +} + func TestRegistry_Info_Good(t *testing.T) { reg := provider.NewRegistry() reg.Add(&fullProvider{}) @@ -147,6 +244,59 @@ func TestRegistry_Info_Good(t *testing.T) { assert.Equal(t, "core-full-panel", info.Element.Tag) } +func TestRegistry_Info_Good_ProxyMetadata(t *testing.T) { + reg := provider.NewRegistry() + reg.Add(provider.NewProxy(provider.ProxyConfig{ + Name: "proxy", + BasePath: "/api/proxy", + Upstream: "http://127.0.0.1:9999", + SpecFile: "/tmp/proxy-openapi.json", + })) + + infos := reg.Info() + require.Len(t, infos, 1) + + info := infos[0] + assert.Equal(t, "proxy", info.Name) + assert.Equal(t, "/api/proxy", info.BasePath) + assert.Equal(t, "/tmp/proxy-openapi.json", info.SpecFile) + assert.Equal(t, "http://127.0.0.1:9999", info.Upstream) +} + +func TestRegistry_InfoIter_Good(t *testing.T) { + reg := provider.NewRegistry() + reg.Add(&fullProvider{}) + + var infos []provider.ProviderInfo + for info := range reg.InfoIter() { + infos = append(infos, info) + } + + require.Len(t, infos, 1) + info := infos[0] + assert.Equal(t, "full", info.Name) + assert.Equal(t, "/api/full", info.BasePath) + assert.Equal(t, []string{"stub.event"}, info.Channels) + require.NotNil(t, info.Element) + assert.Equal(t, "core-full-panel", info.Element.Tag) +} + +func TestRegistry_InfoIter_Good_SnapshotCurrentProviders(t *testing.T) { + reg := provider.NewRegistry() + reg.Add(&fullProvider{}) + + iter := reg.InfoIter() + reg.Add(&specFileProvider{specFile: "/tmp/later.json"}) + + var infos []provider.ProviderInfo + for info := range iter { + infos = append(infos, info) + } + + require.Len(t, infos, 1) + assert.Equal(t, "full", infos[0].Name) +} + func TestRegistry_Iter_Good(t *testing.T) { reg := provider.NewRegistry() reg.Add(&stubProvider{}) @@ -158,3 +308,27 @@ func TestRegistry_Iter_Good(t *testing.T) { } assert.Equal(t, 2, count) } + +func TestRegistry_SpecFiles_Good(t *testing.T) { + reg := provider.NewRegistry() + reg.Add(&stubProvider{}) + reg.Add(&specFileProvider{specFile: "/tmp/b.json"}) + reg.Add(&specFileProvider{specFile: "/tmp/a.yaml"}) + reg.Add(&specFileProvider{specFile: "/tmp/a.yaml"}) + reg.Add(&specFileProvider{specFile: ""}) + + assert.Equal(t, []string{"/tmp/a.yaml", "/tmp/b.json"}, reg.SpecFiles()) +} + +func TestRegistry_SpecFilesIter_Good(t *testing.T) { + reg := provider.NewRegistry() + reg.Add(&specFileProvider{specFile: "/tmp/z.json"}) + reg.Add(&specFileProvider{specFile: "/tmp/x.json"}) + + var files []string + for file := range reg.SpecFilesIter() { + files = append(files, file) + } + + assert.Equal(t, []string{"/tmp/x.json", "/tmp/z.json"}, files) +} diff --git a/ratelimit.go b/ratelimit.go new file mode 100644 index 0000000..308ce7e --- /dev/null +++ b/ratelimit.go @@ -0,0 +1,276 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api + +import ( + "crypto/sha256" + "encoding/hex" + "math" + "net/http" + "strconv" + "strings" + "sync" + "time" + + "github.com/gin-gonic/gin" +) + +const ( + rateLimitCleanupInterval = time.Minute + rateLimitStaleAfter = 10 * time.Minute + + // rateLimitMaxBuckets caps the total number of tracked keys to prevent + // unbounded memory growth under high-cardinality traffic (e.g. scanning + // bots cycling random IPs). When the cap is reached, new keys that cannot + // evict a stale bucket are routed to a shared overflow bucket so requests + // are still rate-limited rather than bypassing the limiter entirely. + rateLimitMaxBuckets = 100_000 + rateLimitOverflowKey = "__overflow__" +) + +type rateLimitStore struct { + mu sync.Mutex + buckets map[string]*rateLimitBucket + limit int + lastSweep time.Time +} + +type rateLimitBucket struct { + mu sync.Mutex + tokens float64 + last time.Time + lastSeen time.Time +} + +type rateLimitDecision struct { + allowed bool + retryAfter time.Duration + limit int + remaining int + resetAt time.Time +} + +func newRateLimitStore(limit int) *rateLimitStore { + now := time.Now() + return &rateLimitStore{ + buckets: make(map[string]*rateLimitBucket), + limit: limit, + lastSweep: now, + } +} + +func (s *rateLimitStore) allow(key string) rateLimitDecision { + now := time.Now() + + s.mu.Lock() + bucket, ok := s.buckets[key] + if !ok || now.Sub(bucket.lastSeen) > rateLimitStaleAfter { + // Enforce the bucket cap before inserting a new entry. First try to + // evict a single stale entry; if none exists and the map is full, + // route the request to the shared overflow bucket so it is still + // rate-limited rather than bypassing the limiter. + if !ok && len(s.buckets) >= rateLimitMaxBuckets { + evicted := false + for k, candidate := range s.buckets { + if now.Sub(candidate.lastSeen) > rateLimitStaleAfter { + delete(s.buckets, k) + evicted = true + break + } + } + if !evicted { + // Cap reached and no stale entry to evict: use overflow bucket. + key = rateLimitOverflowKey + if ob, exists := s.buckets[key]; exists { + bucket = ob + ok = true + } + } + } + + if !ok { + bucket = &rateLimitBucket{ + tokens: float64(s.limit), + last: now, + lastSeen: now, + } + s.buckets[key] = bucket + } else { + bucket.lastSeen = now + } + } else { + bucket.lastSeen = now + } + + if now.Sub(s.lastSweep) >= rateLimitCleanupInterval { + for k, candidate := range s.buckets { + if now.Sub(candidate.lastSeen) > rateLimitStaleAfter { + delete(s.buckets, k) + } + } + s.lastSweep = now + } + s.mu.Unlock() + + bucket.mu.Lock() + defer bucket.mu.Unlock() + + elapsed := now.Sub(bucket.last) + if elapsed > 0 { + refill := elapsed.Seconds() * float64(s.limit) + if bucket.tokens+refill > float64(s.limit) { + bucket.tokens = float64(s.limit) + } else { + bucket.tokens += refill + } + bucket.last = now + } + + if bucket.tokens >= 1 { + bucket.tokens-- + return rateLimitDecision{ + allowed: true, + limit: s.limit, + remaining: int(math.Floor(bucket.tokens)), + resetAt: now.Add(timeUntilFull(bucket.tokens, s.limit)), + } + } + + deficit := 1 - bucket.tokens + wait := time.Duration(deficit / float64(s.limit) * float64(time.Second)) + if wait <= 0 { + wait = time.Second / time.Duration(s.limit) + if wait <= 0 { + wait = time.Second + } + } + + return rateLimitDecision{ + allowed: false, + retryAfter: wait, + limit: s.limit, + remaining: 0, + resetAt: now.Add(wait), + } +} + +func rateLimitMiddleware(limit int) gin.HandlerFunc { + if limit <= 0 { + return func(c *gin.Context) { + c.Next() + } + } + + store := newRateLimitStore(limit) + + return func(c *gin.Context) { + key := clientRateLimitKey(c) + decision := store.allow(key) + if !decision.allowed { + secs := int(decision.retryAfter / time.Second) + if decision.retryAfter%time.Second != 0 { + secs++ + } + if secs < 1 { + secs = 1 + } + setRateLimitHeaders(c, decision.limit, decision.remaining, decision.resetAt) + c.Header("Retry-After", strconv.Itoa(secs)) + c.AbortWithStatusJSON(http.StatusTooManyRequests, Fail( + "rate_limit_exceeded", + "Too many requests", + )) + return + } + + setRateLimitHeaders(c, decision.limit, decision.remaining, decision.resetAt) + c.Next() + } +} + +func setRateLimitHeaders(c *gin.Context, limit, remaining int, resetAt time.Time) { + if limit > 0 { + c.Header("X-RateLimit-Limit", strconv.Itoa(limit)) + } + if remaining < 0 { + remaining = 0 + } + c.Header("X-RateLimit-Remaining", strconv.Itoa(remaining)) + if !resetAt.IsZero() { + reset := resetAt.Unix() + if reset <= time.Now().Unix() { + reset = time.Now().Add(time.Second).Unix() + } + c.Header("X-RateLimit-Reset", strconv.FormatInt(reset, 10)) + } +} + +func timeUntilFull(tokens float64, limit int) time.Duration { + if limit <= 0 { + return 0 + } + missing := float64(limit) - tokens + if missing <= 0 { + return 0 + } + seconds := missing / float64(limit) + if seconds <= 0 { + return 0 + } + return time.Duration(math.Ceil(seconds * float64(time.Second))) +} + +// clientRateLimitKey derives a bucket key for the request. It prefers a +// validated principal placed in context by auth middleware, then falls back to +// raw credential headers (X-API-Key or Bearer token, hashed with SHA-256 so +// secrets are never stored in the bucket map), and finally falls back to the +// client IP when no credentials are present. +func clientRateLimitKey(c *gin.Context) string { + // Prefer a validated principal placed in context by auth middleware. + if principal, ok := c.Get("principal"); ok && principal != nil { + if s, ok := principal.(string); ok && s != "" { + return "principal:" + s + } + } + if userID, ok := c.Get("userID"); ok && userID != nil { + if s, ok := userID.(string); ok && s != "" { + return "user:" + s + } + } + + // Fall back to credential headers before the IP so that different API + // keys coming from the same NAT address are bucketed independently. The + // raw secret is never stored — it is hashed with SHA-256 first. + if apiKey := strings.TrimSpace(c.GetHeader("X-API-Key")); apiKey != "" { + h := sha256.Sum256([]byte(apiKey)) + return "cred:sha256:" + hex.EncodeToString(h[:]) + } + if bearer := bearerTokenFromHeader(c.GetHeader("Authorization")); bearer != "" { + h := sha256.Sum256([]byte(bearer)) + return "cred:sha256:" + hex.EncodeToString(h[:]) + } + + // Last resort: fall back to IP address. + if ip := c.ClientIP(); ip != "" { + return "ip:" + ip + } + if c.Request != nil && c.Request.RemoteAddr != "" { + return "ip:" + c.Request.RemoteAddr + } + + return "ip:unknown" +} + +func bearerTokenFromHeader(header string) string { + header = strings.TrimSpace(header) + if header == "" { + return "" + } + + parts := strings.SplitN(header, " ", 2) + if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") { + return "" + } + + return strings.TrimSpace(parts[1]) +} diff --git a/ratelimit_test.go b/ratelimit_test.go new file mode 100644 index 0000000..24d75ed --- /dev/null +++ b/ratelimit_test.go @@ -0,0 +1,240 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api_test + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" + + api "dappco.re/go/core/api" +) + +type rateLimitTestGroup struct{} + +func (r *rateLimitTestGroup) Name() string { return "rate-limit" } +func (r *rateLimitTestGroup) BasePath() string { return "/rate" } +func (r *rateLimitTestGroup) RegisterRoutes(rg *gin.RouterGroup) { + rg.GET("/ping", func(c *gin.Context) { + c.JSON(http.StatusOK, api.OK("pong")) + }) +} + +func TestWithRateLimit_Good_AllowsBurstThenRejects(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithRateLimit(2)) + e.Register(&rateLimitTestGroup{}) + + h := e.Handler() + + w1 := httptest.NewRecorder() + req1, _ := http.NewRequest(http.MethodGet, "/rate/ping", nil) + req1.RemoteAddr = "203.0.113.10:1234" + h.ServeHTTP(w1, req1) + if w1.Code != http.StatusOK { + t.Fatalf("expected first request to succeed, got %d", w1.Code) + } + if got := w1.Header().Get("X-RateLimit-Limit"); got != "2" { + t.Fatalf("expected X-RateLimit-Limit=2, got %q", got) + } + if got := w1.Header().Get("X-RateLimit-Remaining"); got != "1" { + t.Fatalf("expected X-RateLimit-Remaining=1, got %q", got) + } + if got := w1.Header().Get("X-RateLimit-Reset"); got == "" { + t.Fatal("expected X-RateLimit-Reset on successful response") + } + + w2 := httptest.NewRecorder() + req2, _ := http.NewRequest(http.MethodGet, "/rate/ping", nil) + req2.RemoteAddr = "203.0.113.10:1234" + h.ServeHTTP(w2, req2) + if w2.Code != http.StatusOK { + t.Fatalf("expected second request to succeed, got %d", w2.Code) + } + + w3 := httptest.NewRecorder() + req3, _ := http.NewRequest(http.MethodGet, "/rate/ping", nil) + req3.RemoteAddr = "203.0.113.10:1234" + h.ServeHTTP(w3, req3) + if w3.Code != http.StatusTooManyRequests { + t.Fatalf("expected third request to be rate limited, got %d", w3.Code) + } + + if got := w3.Header().Get("Retry-After"); got == "" { + t.Fatal("expected Retry-After header on 429 response") + } + if got := w3.Header().Get("X-RateLimit-Limit"); got != "2" { + t.Fatalf("expected X-RateLimit-Limit=2 on 429, got %q", got) + } + if got := w3.Header().Get("X-RateLimit-Remaining"); got != "0" { + t.Fatalf("expected X-RateLimit-Remaining=0 on 429, got %q", got) + } + if got := w3.Header().Get("X-RateLimit-Reset"); got == "" { + t.Fatal("expected X-RateLimit-Reset on 429 response") + } + + var resp api.Response[any] + if err := json.Unmarshal(w3.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.Success { + t.Fatal("expected Success=false for rate limited response") + } + if resp.Error == nil || resp.Error.Code != "rate_limit_exceeded" { + t.Fatalf("expected rate_limit_exceeded error, got %+v", resp.Error) + } +} + +func TestWithRateLimit_Good_IsolatesPerIP(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithRateLimit(1)) + e.Register(&rateLimitTestGroup{}) + + h := e.Handler() + + w1 := httptest.NewRecorder() + req1, _ := http.NewRequest(http.MethodGet, "/rate/ping", nil) + req1.RemoteAddr = "203.0.113.10:1234" + h.ServeHTTP(w1, req1) + if w1.Code != http.StatusOK { + t.Fatalf("expected first IP to succeed, got %d", w1.Code) + } + if got := w1.Header().Get("X-RateLimit-Limit"); got != "1" { + t.Fatalf("expected X-RateLimit-Limit=1, got %q", got) + } + + w2 := httptest.NewRecorder() + req2, _ := http.NewRequest(http.MethodGet, "/rate/ping", nil) + req2.RemoteAddr = "203.0.113.11:1234" + h.ServeHTTP(w2, req2) + if w2.Code != http.StatusOK { + t.Fatalf("expected second IP to have its own bucket, got %d", w2.Code) + } +} + +func TestWithRateLimit_Good_IsolatesPerAPIKey(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithRateLimit(1)) + e.Register(&rateLimitTestGroup{}) + + h := e.Handler() + + w1 := httptest.NewRecorder() + req1, _ := http.NewRequest(http.MethodGet, "/rate/ping", nil) + req1.RemoteAddr = "203.0.113.20:1234" + req1.Header.Set("X-API-Key", "key-a") + h.ServeHTTP(w1, req1) + if w1.Code != http.StatusOK { + t.Fatalf("expected first API key request to succeed, got %d", w1.Code) + } + + w2 := httptest.NewRecorder() + req2, _ := http.NewRequest(http.MethodGet, "/rate/ping", nil) + req2.RemoteAddr = "203.0.113.20:1234" + req2.Header.Set("X-API-Key", "key-b") + h.ServeHTTP(w2, req2) + if w2.Code != http.StatusOK { + t.Fatalf("expected second API key to have its own bucket, got %d", w2.Code) + } + + w3 := httptest.NewRecorder() + req3, _ := http.NewRequest(http.MethodGet, "/rate/ping", nil) + req3.RemoteAddr = "203.0.113.20:1234" + req3.Header.Set("X-API-Key", "key-a") + h.ServeHTTP(w3, req3) + if w3.Code != http.StatusTooManyRequests { + t.Fatalf("expected repeated API key to be rate limited, got %d", w3.Code) + } +} + +func TestWithRateLimit_Good_UsesBearerTokenWhenPresent(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithRateLimit(1)) + e.Register(&rateLimitTestGroup{}) + + h := e.Handler() + + w1 := httptest.NewRecorder() + req1, _ := http.NewRequest(http.MethodGet, "/rate/ping", nil) + req1.RemoteAddr = "203.0.113.30:1234" + req1.Header.Set("Authorization", "Bearer token-a") + h.ServeHTTP(w1, req1) + if w1.Code != http.StatusOK { + t.Fatalf("expected first bearer token request to succeed, got %d", w1.Code) + } + + w2 := httptest.NewRecorder() + req2, _ := http.NewRequest(http.MethodGet, "/rate/ping", nil) + req2.RemoteAddr = "203.0.113.30:1234" + req2.Header.Set("Authorization", "Bearer token-b") + h.ServeHTTP(w2, req2) + if w2.Code != http.StatusOK { + t.Fatalf("expected second bearer token to have its own bucket, got %d", w2.Code) + } + + w3 := httptest.NewRecorder() + req3, _ := http.NewRequest(http.MethodGet, "/rate/ping", nil) + req3.RemoteAddr = "203.0.113.30:1234" + req3.Header.Set("Authorization", "Bearer token-a") + h.ServeHTTP(w3, req3) + if w3.Code != http.StatusTooManyRequests { + t.Fatalf("expected repeated bearer token to be rate limited, got %d", w3.Code) + } +} + +func TestWithRateLimit_Good_RefillsOverTime(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithRateLimit(1)) + e.Register(&rateLimitTestGroup{}) + + h := e.Handler() + + req, _ := http.NewRequest(http.MethodGet, "/rate/ping", nil) + req.RemoteAddr = "203.0.113.12:1234" + + w1 := httptest.NewRecorder() + h.ServeHTTP(w1, req.Clone(req.Context())) + if w1.Code != http.StatusOK { + t.Fatalf("expected first request to succeed, got %d", w1.Code) + } + + w2 := httptest.NewRecorder() + req2 := req.Clone(req.Context()) + req2.RemoteAddr = req.RemoteAddr + h.ServeHTTP(w2, req2) + if w2.Code != http.StatusTooManyRequests { + t.Fatalf("expected second request to be rate limited, got %d", w2.Code) + } + + time.Sleep(1100 * time.Millisecond) + + w3 := httptest.NewRecorder() + req3 := req.Clone(req.Context()) + req3.RemoteAddr = req.RemoteAddr + h.ServeHTTP(w3, req3) + if w3.Code != http.StatusOK { + t.Fatalf("expected bucket to refill after waiting, got %d", w3.Code) + } +} + +func TestWithRateLimit_Ugly_NonPositiveLimitDisablesMiddleware(t *testing.T) { + gin.SetMode(gin.TestMode) + e, _ := api.New(api.WithRateLimit(0)) + e.Register(&rateLimitTestGroup{}) + + h := e.Handler() + + for i := 0; i < 3; i++ { + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/rate/ping", nil) + req.RemoteAddr = "203.0.113.13:1234" + h.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("expected request %d to succeed with disabled limiter, got %d", i+1, w.Code) + } + } +} diff --git a/response.go b/response.go index 2a77e18..45928ce 100644 --- a/response.go +++ b/response.go @@ -2,7 +2,14 @@ package api +import "github.com/gin-gonic/gin" + // Response is the standard envelope for all API responses. +// +// Example: +// +// resp := api.OK(map[string]any{"id": 42}) +// resp.Success // true type Response[T any] struct { Success bool `json:"success"` Data T `json:"data,omitempty"` @@ -11,6 +18,10 @@ type Response[T any] struct { } // Error describes a failed API request. +// +// Example: +// +// err := api.Error{Code: "invalid_input", Message: "Name is required"} type Error struct { Code string `json:"code"` Message string `json:"message"` @@ -18,6 +29,10 @@ type Error struct { } // Meta carries pagination and request metadata. +// +// Example: +// +// meta := api.Meta{RequestID: "req_123", Duration: "12ms"} type Meta struct { RequestID string `json:"request_id,omitempty"` Duration string `json:"duration,omitempty"` @@ -27,6 +42,10 @@ type Meta struct { } // OK wraps data in a successful response envelope. +// +// Example: +// +// c.JSON(http.StatusOK, api.OK(map[string]any{"name": "status"})) func OK[T any](data T) Response[T] { return Response[T]{ Success: true, @@ -35,6 +54,10 @@ func OK[T any](data T) Response[T] { } // Fail creates an error response with the given code and message. +// +// Example: +// +// c.JSON(http.StatusBadRequest, api.Fail("invalid_input", "Name is required")) func Fail(code, message string) Response[any] { return Response[any]{ Success: false, @@ -46,6 +69,10 @@ func Fail(code, message string) Response[any] { } // FailWithDetails creates an error response with additional detail payload. +// +// Example: +// +// c.JSON(http.StatusBadRequest, api.FailWithDetails("invalid_input", "Name is required", map[string]any{"field": "name"})) func FailWithDetails(code, message string, details any) Response[any] { return Response[any]{ Success: false, @@ -58,6 +85,10 @@ func FailWithDetails(code, message string, details any) Response[any] { } // Paginated wraps data in a successful response with pagination metadata. +// +// Example: +// +// c.JSON(http.StatusOK, api.Paginated(items, 2, 50, 200)) func Paginated[T any](data T, page, perPage, total int) Response[T] { return Response[T]{ Success: true, @@ -69,3 +100,31 @@ func Paginated[T any](data T, page, perPage, total int) Response[T] { }, } } + +// AttachRequestMeta merges request metadata into an existing response envelope. +// Existing pagination metadata is preserved; request_id and duration are added +// when available from the Gin context. +// +// Example: +// +// resp = api.AttachRequestMeta(c, resp) +func AttachRequestMeta[T any](c *gin.Context, resp Response[T]) Response[T] { + meta := GetRequestMeta(c) + if meta == nil { + return resp + } + + if resp.Meta == nil { + resp.Meta = meta + return resp + } + + if resp.Meta.RequestID == "" { + resp.Meta.RequestID = meta.RequestID + } + if resp.Meta.Duration == "" { + resp.Meta.Duration = meta.Duration + } + + return resp +} diff --git a/response_meta.go b/response_meta.go new file mode 100644 index 0000000..74f9e8a --- /dev/null +++ b/response_meta.go @@ -0,0 +1,286 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api + +import ( + "bufio" + "bytes" + "encoding/json" + "io" + "mime" + "net" + "net/http" + "strconv" + "strings" + "time" + + "github.com/gin-gonic/gin" +) + +// responseMetaRecorder buffers JSON responses so request metadata can be +// injected into the standard envelope before the body is written to the client. +type responseMetaRecorder struct { + gin.ResponseWriter + headers http.Header + body bytes.Buffer + size int + status int + wroteHeader bool + committed bool + passthrough bool +} + +func newResponseMetaRecorder(w gin.ResponseWriter) *responseMetaRecorder { + headers := make(http.Header) + for k, vals := range w.Header() { + headers[k] = append([]string(nil), vals...) + } + + return &responseMetaRecorder{ + ResponseWriter: w, + headers: headers, + status: http.StatusOK, + } +} + +func (w *responseMetaRecorder) Header() http.Header { + return w.headers +} + +func (w *responseMetaRecorder) WriteHeader(code int) { + if w.passthrough { + w.status = code + w.wroteHeader = true + w.ResponseWriter.WriteHeader(code) + return + } + w.status = code + w.wroteHeader = true +} + +func (w *responseMetaRecorder) WriteHeaderNow() { + if w.passthrough { + w.wroteHeader = true + w.ResponseWriter.WriteHeaderNow() + return + } + w.wroteHeader = true +} + +func (w *responseMetaRecorder) Write(data []byte) (int, error) { + if w.passthrough { + if !w.wroteHeader { + w.WriteHeader(http.StatusOK) + } + return w.ResponseWriter.Write(data) + } + if !w.wroteHeader { + w.WriteHeader(http.StatusOK) + } + n, err := w.body.Write(data) + w.size += n + return n, err +} + +func (w *responseMetaRecorder) WriteString(s string) (int, error) { + if w.passthrough { + if !w.wroteHeader { + w.WriteHeader(http.StatusOK) + } + return w.ResponseWriter.WriteString(s) + } + if !w.wroteHeader { + w.WriteHeader(http.StatusOK) + } + n, err := w.body.WriteString(s) + w.size += n + return n, err +} + +func (w *responseMetaRecorder) Flush() { + if w.passthrough { + if f, ok := w.ResponseWriter.(http.Flusher); ok { + f.Flush() + } + return + } + + if !w.wroteHeader { + w.WriteHeader(http.StatusOK) + } + + w.commit(true) + w.passthrough = true + + if f, ok := w.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +func (w *responseMetaRecorder) Status() int { + if w.wroteHeader { + return w.status + } + + return http.StatusOK +} + +func (w *responseMetaRecorder) Size() int { + if w.passthrough { + return w.ResponseWriter.Size() + } + return w.size +} + +func (w *responseMetaRecorder) Written() bool { + return w.wroteHeader +} + +func (w *responseMetaRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if w.passthrough { + if h, ok := w.ResponseWriter.(http.Hijacker); ok { + return h.Hijack() + } + return nil, nil, io.ErrClosedPipe + } + + w.wroteHeader = true + w.passthrough = true + + if h, ok := w.ResponseWriter.(http.Hijacker); ok { + return h.Hijack() + } + return nil, nil, io.ErrClosedPipe +} + +func (w *responseMetaRecorder) commit(writeBody bool) { + if w.committed { + return + } + + for k := range w.ResponseWriter.Header() { + w.ResponseWriter.Header().Del(k) + } + + for k, vals := range w.headers { + for _, v := range vals { + w.ResponseWriter.Header().Add(k, v) + } + } + + w.ResponseWriter.WriteHeader(w.Status()) + if writeBody { + _, _ = w.ResponseWriter.Write(w.body.Bytes()) + w.body.Reset() + } + w.committed = true +} + +// responseMetaMiddleware injects request metadata into JSON envelope +// responses before they are written to the client. +func responseMetaMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + if _, ok := c.Get(requestStartContextKey); !ok { + c.Set(requestStartContextKey, time.Now()) + } + + recorder := newResponseMetaRecorder(c.Writer) + c.Writer = recorder + + c.Next() + + if recorder.passthrough { + return + } + + body := recorder.body.Bytes() + if meta := GetRequestMeta(c); meta != nil && shouldAttachResponseMeta(recorder.Header().Get("Content-Type"), body) { + if refreshed := refreshResponseMetaBody(body, meta); refreshed != nil { + body = refreshed + } + } + + recorder.body.Reset() + _, _ = recorder.body.Write(body) + recorder.size = len(body) + recorder.Header().Set("Content-Length", strconv.Itoa(len(body))) + recorder.commit(true) + } +} + +// refreshResponseMetaBody injects request metadata into a cached or buffered +// JSON envelope without disturbing existing pagination metadata. +func refreshResponseMetaBody(body []byte, meta *Meta) []byte { + if meta == nil { + return body + } + + var payload any + dec := json.NewDecoder(bytes.NewReader(body)) + dec.UseNumber() + if err := dec.Decode(&payload); err != nil { + return body + } + + var extra any + if err := dec.Decode(&extra); err != io.EOF { + return body + } + + obj, ok := payload.(map[string]any) + if !ok { + return body + } + + if _, ok := obj["success"]; !ok { + if _, ok := obj["error"]; !ok { + return body + } + } + + current := map[string]any{} + if existing, ok := obj["meta"].(map[string]any); ok { + current = existing + } + + if meta.RequestID != "" { + current["request_id"] = meta.RequestID + } + if meta.Duration != "" { + current["duration"] = meta.Duration + } + + obj["meta"] = current + + updated, err := json.Marshal(obj) + if err != nil { + return body + } + + return updated +} + +func shouldAttachResponseMeta(contentType string, body []byte) bool { + if !isJSONContentType(contentType) { + return false + } + + trimmed := bytes.TrimSpace(body) + return len(trimmed) > 0 && trimmed[0] == '{' +} + +func isJSONContentType(contentType string) bool { + if strings.TrimSpace(contentType) == "" { + return false + } + + mediaType, _, err := mime.ParseMediaType(contentType) + if err != nil { + mediaType = strings.TrimSpace(contentType) + } + mediaType = strings.ToLower(mediaType) + + return mediaType == "application/json" || + strings.HasSuffix(mediaType, "+json") || + strings.HasSuffix(mediaType, "/json") +} diff --git a/runtime_config.go b/runtime_config.go new file mode 100644 index 0000000..3f537aa --- /dev/null +++ b/runtime_config.go @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api + +// RuntimeConfig captures the engine's current runtime-facing configuration in +// a single snapshot. +// +// It groups the existing Swagger, transport, GraphQL, cache, and i18n snapshots +// so callers can inspect the active engine surface without joining multiple +// method results themselves. +// +// Example: +// +// cfg := engine.RuntimeConfig() +type RuntimeConfig struct { + Swagger SwaggerConfig + Transport TransportConfig + GraphQL GraphQLConfig + Cache CacheConfig + I18n I18nConfig + Authentik AuthentikConfig +} + +// RuntimeConfig returns a stable snapshot of the engine's current runtime +// configuration. +// +// The result clones the underlying snapshots so callers can safely retain or +// modify the returned value. +// +// Example: +// +// cfg := engine.RuntimeConfig() +func (e *Engine) RuntimeConfig() RuntimeConfig { + if e == nil { + return RuntimeConfig{} + } + + return RuntimeConfig{ + Swagger: e.SwaggerConfig(), + Transport: e.TransportConfig(), + GraphQL: e.GraphQLConfig(), + Cache: e.CacheConfig(), + I18n: e.I18nConfig(), + Authentik: e.AuthentikConfig(), + } +} diff --git a/servers.go b/servers.go new file mode 100644 index 0000000..676eafe --- /dev/null +++ b/servers.go @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api + +import "strings" + +// normaliseServers trims whitespace, removes empty entries, and preserves +// the first occurrence of each server URL. +func normaliseServers(servers []string) []string { + if len(servers) == 0 { + return nil + } + + cleaned := make([]string, 0, len(servers)) + seen := make(map[string]struct{}, len(servers)) + + for _, server := range servers { + server = normaliseServer(server) + if server == "" { + continue + } + if _, ok := seen[server]; ok { + continue + } + seen[server] = struct{}{} + cleaned = append(cleaned, server) + } + + if len(cleaned) == 0 { + return nil + } + + return cleaned +} + +// normaliseServer trims surrounding whitespace and removes a trailing slash +// from non-root server URLs so equivalent metadata collapses to one entry. +func normaliseServer(server string) string { + server = strings.TrimSpace(server) + if server == "" { + return "" + } + if server == "/" { + return server + } + + server = strings.TrimRight(server, "/") + if server == "" { + return "/" + } + + return server +} diff --git a/spec_builder_helper.go b/spec_builder_helper.go new file mode 100644 index 0000000..5eaa8e6 --- /dev/null +++ b/spec_builder_helper.go @@ -0,0 +1,283 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api + +import ( + "reflect" + "slices" + "strings" +) + +// SwaggerConfig captures the configured Swagger/OpenAPI metadata for an Engine. +// +// It is intentionally small and serialisable so callers can inspect the active +// documentation surface without rebuilding an OpenAPI document. +// +// Example: +// +// cfg := api.SwaggerConfig{Title: "Service", Summary: "Public API"} +type SwaggerConfig struct { + Enabled bool + Path string + Title string + Summary string + Description string + Version string + TermsOfService string + ContactName string + ContactURL string + ContactEmail string + Servers []string + LicenseName string + LicenseURL string + SecuritySchemes map[string]any + ExternalDocsDescription string + ExternalDocsURL string +} + +// OpenAPISpecBuilder returns a SpecBuilder populated from the engine's current +// Swagger, transport, cache, i18n, and Authentik metadata. +// +// Example: +// +// builder := engine.OpenAPISpecBuilder() +func (e *Engine) OpenAPISpecBuilder() *SpecBuilder { + if e == nil { + return &SpecBuilder{} + } + + runtime := e.RuntimeConfig() + builder := &SpecBuilder{ + Title: runtime.Swagger.Title, + Summary: runtime.Swagger.Summary, + Description: runtime.Swagger.Description, + Version: runtime.Swagger.Version, + SwaggerEnabled: runtime.Swagger.Enabled, + TermsOfService: runtime.Swagger.TermsOfService, + ContactName: runtime.Swagger.ContactName, + ContactURL: runtime.Swagger.ContactURL, + ContactEmail: runtime.Swagger.ContactEmail, + Servers: slices.Clone(runtime.Swagger.Servers), + LicenseName: runtime.Swagger.LicenseName, + LicenseURL: runtime.Swagger.LicenseURL, + SecuritySchemes: cloneSecuritySchemes(runtime.Swagger.SecuritySchemes), + ExternalDocsDescription: runtime.Swagger.ExternalDocsDescription, + ExternalDocsURL: runtime.Swagger.ExternalDocsURL, + } + + builder.SwaggerPath = runtime.Transport.SwaggerPath + builder.GraphQLEnabled = runtime.GraphQL.Enabled + builder.GraphQLPath = runtime.GraphQL.Path + builder.GraphQLPlayground = runtime.GraphQL.Playground + builder.GraphQLPlaygroundPath = runtime.GraphQL.PlaygroundPath + builder.WSPath = runtime.Transport.WSPath + builder.WSEnabled = runtime.Transport.WSEnabled + builder.SSEPath = runtime.Transport.SSEPath + builder.SSEEnabled = runtime.Transport.SSEEnabled + builder.PprofEnabled = runtime.Transport.PprofEnabled + builder.ExpvarEnabled = runtime.Transport.ExpvarEnabled + + builder.CacheEnabled = runtime.Cache.Enabled + if runtime.Cache.TTL > 0 { + builder.CacheTTL = runtime.Cache.TTL.String() + } + builder.CacheMaxEntries = runtime.Cache.MaxEntries + builder.CacheMaxBytes = runtime.Cache.MaxBytes + + builder.I18nDefaultLocale = runtime.I18n.DefaultLocale + builder.I18nSupportedLocales = slices.Clone(runtime.I18n.Supported) + builder.AuthentikIssuer = runtime.Authentik.Issuer + builder.AuthentikClientID = runtime.Authentik.ClientID + builder.AuthentikTrustedProxy = runtime.Authentik.TrustedProxy + builder.AuthentikPublicPaths = slices.Clone(runtime.Authentik.PublicPaths) + + return builder +} + +// SwaggerConfig returns the currently configured Swagger metadata for the engine. +// +// The result snapshots the Engine state at call time and clones slices/maps so +// callers can safely reuse or modify the returned value. +// +// Example: +// +// cfg := engine.SwaggerConfig() +func (e *Engine) SwaggerConfig() SwaggerConfig { + if e == nil { + return SwaggerConfig{} + } + + cfg := SwaggerConfig{ + Enabled: e.swaggerEnabled, + Title: e.swaggerTitle, + Summary: e.swaggerSummary, + Description: e.swaggerDesc, + Version: e.swaggerVersion, + TermsOfService: e.swaggerTermsOfService, + ContactName: e.swaggerContactName, + ContactURL: e.swaggerContactURL, + ContactEmail: e.swaggerContactEmail, + Servers: slices.Clone(e.swaggerServers), + LicenseName: e.swaggerLicenseName, + LicenseURL: e.swaggerLicenseURL, + SecuritySchemes: cloneSecuritySchemes(e.swaggerSecuritySchemes), + ExternalDocsDescription: e.swaggerExternalDocsDescription, + ExternalDocsURL: e.swaggerExternalDocsURL, + } + + if strings.TrimSpace(e.swaggerPath) != "" { + cfg.Path = normaliseSwaggerPath(e.swaggerPath) + } + + return cfg +} + +func cloneSecuritySchemes(schemes map[string]any) map[string]any { + if len(schemes) == 0 { + return nil + } + + out := make(map[string]any, len(schemes)) + for name, scheme := range schemes { + out[name] = cloneOpenAPIValue(scheme) + } + return out +} + +func cloneRouteDescription(rd RouteDescription) RouteDescription { + out := rd + + out.Tags = slices.Clone(rd.Tags) + out.Security = cloneSecurityRequirements(rd.Security) + out.Parameters = cloneParameterDescriptions(rd.Parameters) + out.RequestBody = cloneOpenAPIObject(rd.RequestBody) + out.RequestExample = cloneOpenAPIValue(rd.RequestExample) + out.Response = cloneOpenAPIObject(rd.Response) + out.ResponseExample = cloneOpenAPIValue(rd.ResponseExample) + out.ResponseHeaders = cloneStringMap(rd.ResponseHeaders) + + return out +} + +func cloneParameterDescriptions(params []ParameterDescription) []ParameterDescription { + if params == nil { + return nil + } + if len(params) == 0 { + return []ParameterDescription{} + } + + out := make([]ParameterDescription, len(params)) + for i, param := range params { + out[i] = param + out[i].Schema = cloneOpenAPIObject(param.Schema) + out[i].Example = cloneOpenAPIValue(param.Example) + } + + return out +} + +func cloneSecurityRequirements(security []map[string][]string) []map[string][]string { + if security == nil { + return nil + } + if len(security) == 0 { + return []map[string][]string{} + } + + out := make([]map[string][]string, len(security)) + for i, requirement := range security { + if len(requirement) == 0 { + continue + } + + cloned := make(map[string][]string, len(requirement)) + for name, scopes := range requirement { + cloned[name] = slices.Clone(scopes) + } + out[i] = cloned + } + + return out +} + +func cloneOpenAPIObject(v map[string]any) map[string]any { + if v == nil { + return nil + } + if len(v) == 0 { + return map[string]any{} + } + + cloned, _ := cloneOpenAPIValue(v).(map[string]any) + return cloned +} + +func cloneStringMap(v map[string]string) map[string]string { + if v == nil { + return nil + } + if len(v) == 0 { + return map[string]string{} + } + + out := make(map[string]string, len(v)) + for key, value := range v { + out[key] = value + } + return out +} + +// cloneOpenAPIValue recursively copies JSON-like OpenAPI values so callers can +// safely retain and reuse their original maps after configuring an engine. +func cloneOpenAPIValue(v any) any { + switch value := v.(type) { + case map[string]any: + out := make(map[string]any, len(value)) + for k, nested := range value { + out[k] = cloneOpenAPIValue(nested) + } + return out + case []any: + out := make([]any, len(value)) + for i, nested := range value { + out[i] = cloneOpenAPIValue(nested) + } + return out + default: + rv := reflect.ValueOf(v) + if !rv.IsValid() { + return nil + } + + switch rv.Kind() { + case reflect.Map: + out := reflect.MakeMapWithSize(rv.Type(), rv.Len()) + for _, key := range rv.MapKeys() { + cloned := cloneOpenAPIValue(rv.MapIndex(key).Interface()) + if cloned == nil { + out.SetMapIndex(key, reflect.Zero(rv.Type().Elem())) + continue + } + out.SetMapIndex(key, reflect.ValueOf(cloned)) + } + return out.Interface() + case reflect.Slice: + if rv.IsNil() { + return v + } + out := reflect.MakeSlice(rv.Type(), rv.Len(), rv.Len()) + for i := 0; i < rv.Len(); i++ { + cloned := cloneOpenAPIValue(rv.Index(i).Interface()) + if cloned == nil { + out.Index(i).Set(reflect.Zero(rv.Type().Elem())) + continue + } + out.Index(i).Set(reflect.ValueOf(cloned)) + } + return out.Interface() + default: + return value + } + } +} diff --git a/spec_builder_helper_internal_test.go b/spec_builder_helper_internal_test.go new file mode 100644 index 0000000..eb47f7d --- /dev/null +++ b/spec_builder_helper_internal_test.go @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api + +import "testing" + +func TestEngine_SwaggerConfig_Good_NormalisesPathAtSnapshot(t *testing.T) { + e := &Engine{ + swaggerPath: " /docs/ ", + } + + cfg := e.SwaggerConfig() + if cfg.Path != "/docs" { + t.Fatalf("expected normalised Swagger path /docs, got %q", cfg.Path) + } +} + +func TestEngine_TransportConfig_Good_NormalisesGraphQLPathAtSnapshot(t *testing.T) { + e := &Engine{ + graphql: &graphqlConfig{ + path: " /gql/ ", + }, + } + + cfg := e.TransportConfig() + if cfg.GraphQLPath != "/gql" { + t.Fatalf("expected normalised GraphQL path /gql, got %q", cfg.GraphQLPath) + } +} diff --git a/spec_builder_helper_test.go b/spec_builder_helper_test.go new file mode 100644 index 0000000..8654ed8 --- /dev/null +++ b/spec_builder_helper_test.go @@ -0,0 +1,647 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api_test + +import ( + "encoding/json" + "net/http" + "testing" + "time" + + "github.com/gin-gonic/gin" + "slices" + + api "dappco.re/go/core/api" +) + +func TestEngine_Good_OpenAPISpecBuilderCarriesEngineMetadata(t *testing.T) { + gin.SetMode(gin.TestMode) + + broker := api.NewSSEBroker() + e, err := api.New( + api.WithSwagger("Engine API", "Engine metadata", "2.0.0"), + api.WithSwaggerSummary("Engine overview"), + api.WithSwaggerPath("/docs"), + api.WithSwaggerTermsOfService("https://example.com/terms"), + api.WithSwaggerContact("API Support", "https://example.com/support", "support@example.com"), + api.WithSwaggerServers("https://api.example.com", "/", "https://api.example.com"), + api.WithSwaggerLicense("EUPL-1.2", "https://eupl.eu/1.2/en/"), + api.WithSwaggerSecuritySchemes(map[string]any{ + "apiKeyAuth": map[string]any{ + "type": "apiKey", + "in": "header", + "name": "X-API-Key", + }, + }), + api.WithSwaggerExternalDocs("Developer guide", "https://example.com/docs"), + api.WithCacheLimits(5*time.Minute, 42, 8192), + api.WithI18n(api.I18nConfig{ + DefaultLocale: "en-GB", + Supported: []string{"en-GB", "fr"}, + }), + api.WithAuthentik(api.AuthentikConfig{ + Issuer: "https://auth.example.com", + ClientID: "core-client", + TrustedProxy: true, + PublicPaths: []string{" /public/ ", "docs", "/public"}, + }), + api.WithWSPath("/socket"), + api.WithWSHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})), + api.WithGraphQL(newTestSchema(), api.WithPlayground(), api.WithGraphQLPath("/gql")), + api.WithSSE(broker), + api.WithSSEPath("/events"), + api.WithPprof(), + api.WithExpvar(), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + builder := e.OpenAPISpecBuilder() + data, err := builder.Build(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + info, ok := spec["info"].(map[string]any) + if !ok { + t.Fatal("expected info object in generated spec") + } + if info["title"] != "Engine API" { + t.Fatalf("expected title Engine API, got %v", info["title"]) + } + if info["description"] != "Engine metadata" { + t.Fatalf("expected description Engine metadata, got %v", info["description"]) + } + if info["version"] != "2.0.0" { + t.Fatalf("expected version 2.0.0, got %v", info["version"]) + } + if info["summary"] != "Engine overview" { + t.Fatalf("expected summary Engine overview, got %v", info["summary"]) + } + + if got := spec["x-swagger-ui-path"]; got != "/docs" { + t.Fatalf("expected x-swagger-ui-path=/docs, got %v", got) + } + if got := spec["x-swagger-enabled"]; got != true { + t.Fatalf("expected x-swagger-enabled=true, got %v", got) + } + if got := spec["x-graphql-enabled"]; got != true { + t.Fatalf("expected x-graphql-enabled=true, got %v", got) + } + if got := spec["x-graphql-path"]; got != "/gql" { + t.Fatalf("expected x-graphql-path=/gql, got %v", got) + } + if got := spec["x-graphql-playground"]; got != true { + t.Fatalf("expected x-graphql-playground=true, got %v", got) + } + if got := spec["x-graphql-playground-path"]; got != "/gql/playground" { + t.Fatalf("expected x-graphql-playground-path=/gql/playground, got %v", got) + } + if got := spec["x-ws-path"]; got != "/socket" { + t.Fatalf("expected x-ws-path=/socket, got %v", got) + } + if got := spec["x-ws-enabled"]; got != true { + t.Fatalf("expected x-ws-enabled=true, got %v", got) + } + if got := spec["x-sse-path"]; got != "/events" { + t.Fatalf("expected x-sse-path=/events, got %v", got) + } + if got := spec["x-sse-enabled"]; got != true { + t.Fatalf("expected x-sse-enabled=true, got %v", got) + } + if got := spec["x-pprof-enabled"]; got != true { + t.Fatalf("expected x-pprof-enabled=true, got %v", got) + } + if got := spec["x-expvar-enabled"]; got != true { + t.Fatalf("expected x-expvar-enabled=true, got %v", got) + } + if got := spec["x-cache-enabled"]; got != true { + t.Fatalf("expected x-cache-enabled=true, got %v", got) + } + if got := spec["x-cache-ttl"]; got != "5m0s" { + t.Fatalf("expected x-cache-ttl=5m0s, got %v", got) + } + if got := spec["x-cache-max-entries"]; got != float64(42) { + t.Fatalf("expected x-cache-max-entries=42, got %v", got) + } + if got := spec["x-cache-max-bytes"]; got != float64(8192) { + t.Fatalf("expected x-cache-max-bytes=8192, got %v", got) + } + if got := spec["x-i18n-default-locale"]; got != "en-GB" { + t.Fatalf("expected x-i18n-default-locale=en-GB, got %v", got) + } + locales, ok := spec["x-i18n-supported-locales"].([]any) + if !ok { + t.Fatalf("expected x-i18n-supported-locales array, got %T", spec["x-i18n-supported-locales"]) + } + if len(locales) != 2 || locales[0] != "en-GB" || locales[1] != "fr" { + t.Fatalf("expected supported locales [en-GB fr], got %v", locales) + } + if got := spec["x-authentik-issuer"]; got != "https://auth.example.com" { + t.Fatalf("expected x-authentik-issuer=https://auth.example.com, got %v", got) + } + if got := spec["x-authentik-client-id"]; got != "core-client" { + t.Fatalf("expected x-authentik-client-id=core-client, got %v", got) + } + if got := spec["x-authentik-trusted-proxy"]; got != true { + t.Fatalf("expected x-authentik-trusted-proxy=true, got %v", got) + } + publicPaths, ok := spec["x-authentik-public-paths"].([]any) + if !ok { + t.Fatalf("expected x-authentik-public-paths array, got %T", spec["x-authentik-public-paths"]) + } + if len(publicPaths) != 4 || publicPaths[0] != "/health" || publicPaths[1] != "/swagger" || publicPaths[2] != "/docs" || publicPaths[3] != "/public" { + t.Fatalf("expected public paths [/health /swagger /docs /public], got %v", publicPaths) + } + + contact, ok := info["contact"].(map[string]any) + if !ok { + t.Fatal("expected contact metadata in generated spec") + } + if contact["name"] != "API Support" { + t.Fatalf("expected contact name API Support, got %v", contact["name"]) + } + + license, ok := info["license"].(map[string]any) + if !ok { + t.Fatal("expected licence metadata in generated spec") + } + if license["name"] != "EUPL-1.2" { + t.Fatalf("expected licence name EUPL-1.2, got %v", license["name"]) + } + + if info["termsOfService"] != "https://example.com/terms" { + t.Fatalf("expected termsOfService to be preserved, got %v", info["termsOfService"]) + } + + securitySchemes, ok := spec["components"].(map[string]any)["securitySchemes"].(map[string]any) + if !ok { + t.Fatal("expected securitySchemes metadata in generated spec") + } + apiKeyAuth, ok := securitySchemes["apiKeyAuth"].(map[string]any) + if !ok { + t.Fatal("expected apiKeyAuth security scheme in generated spec") + } + if apiKeyAuth["type"] != "apiKey" { + t.Fatalf("expected apiKeyAuth.type=apiKey, got %v", apiKeyAuth["type"]) + } + if apiKeyAuth["in"] != "header" { + t.Fatalf("expected apiKeyAuth.in=header, got %v", apiKeyAuth["in"]) + } + if apiKeyAuth["name"] != "X-API-Key" { + t.Fatalf("expected apiKeyAuth.name=X-API-Key, got %v", apiKeyAuth["name"]) + } + + externalDocs, ok := spec["externalDocs"].(map[string]any) + if !ok { + t.Fatal("expected externalDocs metadata in generated spec") + } + if externalDocs["url"] != "https://example.com/docs" { + t.Fatalf("expected externalDocs url to be preserved, got %v", externalDocs["url"]) + } + + servers, ok := spec["servers"].([]any) + if !ok { + t.Fatalf("expected servers array in generated spec, got %T", spec["servers"]) + } + if len(servers) != 2 { + t.Fatalf("expected 2 normalised servers, got %d", len(servers)) + } + if servers[0].(map[string]any)["url"] != "https://api.example.com" { + t.Fatalf("expected first server to be https://api.example.com, got %v", servers[0]) + } + if servers[1].(map[string]any)["url"] != "/" { + t.Fatalf("expected second server to be /, got %v", servers[1]) + } + + paths, ok := spec["paths"].(map[string]any) + if !ok { + t.Fatalf("expected paths object in generated spec, got %T", spec["paths"]) + } + if _, ok := paths["/gql"]; !ok { + t.Fatal("expected GraphQL path from engine metadata in generated spec") + } + if _, ok := paths["/gql/playground"]; !ok { + t.Fatal("expected GraphQL playground path from engine metadata in generated spec") + } + if _, ok := paths["/socket"]; !ok { + t.Fatal("expected custom WebSocket path from engine metadata in generated spec") + } + if _, ok := paths["/events"]; !ok { + t.Fatal("expected SSE path from engine metadata in generated spec") + } + if _, ok := paths["/debug/pprof"]; !ok { + t.Fatal("expected pprof path from engine metadata in generated spec") + } + if _, ok := paths["/debug/vars"]; !ok { + t.Fatal("expected expvar path from engine metadata in generated spec") + } +} + +func TestEngine_Good_SwaggerConfigCarriesEngineMetadata(t *testing.T) { + gin.SetMode(gin.TestMode) + + e, err := api.New( + api.WithSwagger("Engine API", "Engine metadata", "2.0.0"), + api.WithSwaggerSummary("Engine overview"), + api.WithSwaggerTermsOfService("https://example.com/terms"), + api.WithSwaggerContact("API Support", "https://example.com/support", "support@example.com"), + api.WithSwaggerServers("https://api.example.com", "/", "https://api.example.com"), + api.WithSwaggerLicense("EUPL-1.2", "https://eupl.eu/1.2/en/"), + api.WithSwaggerSecuritySchemes(map[string]any{ + "apiKeyAuth": map[string]any{ + "type": "apiKey", + "in": "header", + "name": "X-API-Key", + }, + }), + api.WithSwaggerExternalDocs("Developer guide", "https://example.com/docs"), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + cfg := e.SwaggerConfig() + if !cfg.Enabled { + t.Fatal("expected Swagger to be enabled") + } + if cfg.Path != "" { + t.Fatalf("expected empty Swagger path when none is configured, got %q", cfg.Path) + } + if cfg.Title != "Engine API" { + t.Fatalf("expected title Engine API, got %q", cfg.Title) + } + if cfg.Description != "Engine metadata" { + t.Fatalf("expected description Engine metadata, got %q", cfg.Description) + } + if cfg.Version != "2.0.0" { + t.Fatalf("expected version 2.0.0, got %q", cfg.Version) + } + if cfg.Summary != "Engine overview" { + t.Fatalf("expected summary Engine overview, got %q", cfg.Summary) + } + if cfg.TermsOfService != "https://example.com/terms" { + t.Fatalf("expected termsOfService to be preserved, got %q", cfg.TermsOfService) + } + if cfg.ContactName != "API Support" { + t.Fatalf("expected contact name API Support, got %q", cfg.ContactName) + } + if cfg.LicenseName != "EUPL-1.2" { + t.Fatalf("expected licence name EUPL-1.2, got %q", cfg.LicenseName) + } + if cfg.ExternalDocsURL != "https://example.com/docs" { + t.Fatalf("expected external docs URL https://example.com/docs, got %q", cfg.ExternalDocsURL) + } + if len(cfg.Servers) != 2 { + t.Fatalf("expected 2 normalised servers, got %d", len(cfg.Servers)) + } + if cfg.Servers[0] != "https://api.example.com" { + t.Fatalf("expected first server to be https://api.example.com, got %q", cfg.Servers[0]) + } + if cfg.Servers[1] != "/" { + t.Fatalf("expected second server to be /, got %q", cfg.Servers[1]) + } + + cfgWithPath, err := api.New( + api.WithSwagger("Engine API", "Engine metadata", "2.0.0"), + api.WithSwaggerPath("/docs"), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + snap := cfgWithPath.SwaggerConfig() + if snap.Path != "/docs" { + t.Fatalf("expected Swagger path /docs, got %q", snap.Path) + } + + apiKeyAuth, ok := cfg.SecuritySchemes["apiKeyAuth"].(map[string]any) + if !ok { + t.Fatal("expected apiKeyAuth security scheme in Swagger config") + } + if apiKeyAuth["name"] != "X-API-Key" { + t.Fatalf("expected apiKeyAuth.name=X-API-Key, got %v", apiKeyAuth["name"]) + } + + cfg.Servers[0] = "https://mutated.example.com" + apiKeyAuth["name"] = "Changed" + + reshot := e.SwaggerConfig() + if reshot.Servers[0] != "https://api.example.com" { + t.Fatalf("expected engine servers to be cloned, got %q", reshot.Servers[0]) + } + reshotScheme, ok := reshot.SecuritySchemes["apiKeyAuth"].(map[string]any) + if !ok { + t.Fatal("expected apiKeyAuth security scheme in cloned Swagger config") + } + if reshotScheme["name"] != "X-API-Key" { + t.Fatalf("expected cloned security scheme name X-API-Key, got %v", reshotScheme["name"]) + } +} + +func TestEngine_Good_SwaggerConfigTrimsRuntimeMetadata(t *testing.T) { + gin.SetMode(gin.TestMode) + + e, err := api.New( + api.WithSwagger(" Engine API ", " Engine metadata ", " 2.0.0 "), + api.WithSwaggerSummary(" Engine overview "), + api.WithSwaggerTermsOfService(" https://example.com/terms "), + api.WithSwaggerContact(" API Support ", " https://example.com/support ", " support@example.com "), + api.WithSwaggerLicense(" EUPL-1.2 ", " https://eupl.eu/1.2/en/ "), + api.WithSwaggerExternalDocs(" Developer guide ", " https://example.com/docs "), + api.WithAuthentik(api.AuthentikConfig{ + Issuer: " https://auth.example.com ", + ClientID: " core-client ", + TrustedProxy: true, + PublicPaths: []string{" /public/ ", " docs ", "/public"}, + }), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + swagger := e.SwaggerConfig() + if swagger.Title != "Engine API" { + t.Fatalf("expected trimmed title Engine API, got %q", swagger.Title) + } + if swagger.Description != "Engine metadata" { + t.Fatalf("expected trimmed description Engine metadata, got %q", swagger.Description) + } + if swagger.Version != "2.0.0" { + t.Fatalf("expected trimmed version 2.0.0, got %q", swagger.Version) + } + if swagger.Summary != "Engine overview" { + t.Fatalf("expected trimmed summary Engine overview, got %q", swagger.Summary) + } + if swagger.TermsOfService != "https://example.com/terms" { + t.Fatalf("expected trimmed termsOfService, got %q", swagger.TermsOfService) + } + if swagger.ContactName != "API Support" || swagger.ContactURL != "https://example.com/support" || swagger.ContactEmail != "support@example.com" { + t.Fatalf("expected trimmed contact metadata, got %+v", swagger) + } + if swagger.LicenseName != "EUPL-1.2" || swagger.LicenseURL != "https://eupl.eu/1.2/en/" { + t.Fatalf("expected trimmed licence metadata, got %+v", swagger) + } + if swagger.ExternalDocsDescription != "Developer guide" || swagger.ExternalDocsURL != "https://example.com/docs" { + t.Fatalf("expected trimmed external docs metadata, got %+v", swagger) + } + + auth := e.AuthentikConfig() + if auth.Issuer != "https://auth.example.com" { + t.Fatalf("expected trimmed issuer, got %q", auth.Issuer) + } + if auth.ClientID != "core-client" { + t.Fatalf("expected trimmed client ID, got %q", auth.ClientID) + } + if want := []string{"/public", "/docs"}; !slices.Equal(auth.PublicPaths, want) { + t.Fatalf("expected trimmed public paths %v, got %v", want, auth.PublicPaths) + } + + builder := e.OpenAPISpecBuilder() + data, err := builder.Build(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + info, ok := spec["info"].(map[string]any) + if !ok { + t.Fatal("expected info object in generated spec") + } + if info["title"] != "Engine API" || info["description"] != "Engine metadata" || info["version"] != "2.0.0" || info["summary"] != "Engine overview" { + t.Fatalf("expected trimmed OpenAPI info block, got %+v", info) + } + if info["termsOfService"] != "https://example.com/terms" { + t.Fatalf("expected trimmed termsOfService in spec, got %v", info["termsOfService"]) + } +} + +func TestEngine_Good_TransportConfigCarriesEngineMetadata(t *testing.T) { + gin.SetMode(gin.TestMode) + + broker := api.NewSSEBroker() + e, err := api.New( + api.WithSwagger("Engine API", "Engine metadata", "2.0.0"), + api.WithSwaggerPath("/docs"), + api.WithWSPath("/socket"), + api.WithWSHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})), + api.WithGraphQL(newTestSchema(), api.WithPlayground(), api.WithGraphQLPath("/gql")), + api.WithSSE(broker), + api.WithSSEPath("/events"), + api.WithPprof(), + api.WithExpvar(), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + cfg := e.TransportConfig() + if !cfg.SwaggerEnabled { + t.Fatal("expected Swagger to be enabled") + } + if cfg.SwaggerPath != "/docs" { + t.Fatalf("expected swagger path /docs, got %q", cfg.SwaggerPath) + } + if cfg.GraphQLPath != "/gql" { + t.Fatalf("expected graphql path /gql, got %q", cfg.GraphQLPath) + } + if !cfg.GraphQLEnabled { + t.Fatal("expected GraphQL to be enabled") + } + if !cfg.GraphQLPlayground { + t.Fatal("expected GraphQL playground to be enabled") + } + if !cfg.WSEnabled { + t.Fatal("expected WebSocket to be enabled") + } + if cfg.WSPath != "/socket" { + t.Fatalf("expected ws path /socket, got %q", cfg.WSPath) + } + if !cfg.SSEEnabled { + t.Fatal("expected SSE to be enabled") + } + if cfg.SSEPath != "/events" { + t.Fatalf("expected sse path /events, got %q", cfg.SSEPath) + } + if !cfg.PprofEnabled { + t.Fatal("expected pprof to be enabled") + } + if !cfg.ExpvarEnabled { + t.Fatal("expected expvar to be enabled") + } +} + +func TestEngine_Good_TransportConfigReportsDisabledSwaggerWithoutUI(t *testing.T) { + gin.SetMode(gin.TestMode) + + e, err := api.New(api.WithSwaggerPath("/docs")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + cfg := e.TransportConfig() + if cfg.SwaggerEnabled { + t.Fatal("expected Swagger to remain disabled when only the path is configured") + } + if cfg.SwaggerPath != "/docs" { + t.Fatalf("expected swagger path /docs, got %q", cfg.SwaggerPath) + } +} + +func TestEngine_Good_OpenAPISpecBuilderExportsDefaultSwaggerPath(t *testing.T) { + gin.SetMode(gin.TestMode) + + e, err := api.New(api.WithSwagger("Engine API", "Engine metadata", "2.0.0")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + builder := e.OpenAPISpecBuilder() + data, err := builder.Build(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + if got := spec["x-swagger-ui-path"]; got != "/swagger" { + t.Fatalf("expected default x-swagger-ui-path=/swagger, got %v", got) + } +} + +func TestEngine_Good_OpenAPISpecBuilderCarriesExplicitSwaggerPathWithoutUI(t *testing.T) { + gin.SetMode(gin.TestMode) + + e, err := api.New(api.WithSwaggerPath("/docs")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + builder := e.OpenAPISpecBuilder() + data, err := builder.Build(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + if got := spec["x-swagger-ui-path"]; got != "/docs" { + t.Fatalf("expected explicit x-swagger-ui-path=/docs, got %v", got) + } +} + +func TestEngine_Good_OpenAPISpecBuilderCarriesConfiguredWSPathWithoutHandler(t *testing.T) { + gin.SetMode(gin.TestMode) + + e, err := api.New(api.WithWSPath("/socket")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + builder := e.OpenAPISpecBuilder() + data, err := builder.Build(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + if got := spec["x-ws-path"]; got != "/socket" { + t.Fatalf("expected x-ws-path=/socket, got %v", got) + } +} + +func TestEngine_Good_OpenAPISpecBuilderCarriesConfiguredSSEPathWithoutBroker(t *testing.T) { + gin.SetMode(gin.TestMode) + + e, err := api.New(api.WithSSE(nil), api.WithSSEPath("/events")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + builder := e.OpenAPISpecBuilder() + data, err := builder.Build(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + if got := spec["x-sse-path"]; got != "/events" { + t.Fatalf("expected x-sse-path=/events, got %v", got) + } +} + +func TestEngine_Good_OpenAPISpecBuilderClonesSecuritySchemes(t *testing.T) { + gin.SetMode(gin.TestMode) + + securityScheme := map[string]any{ + "type": "oauth2", + "flows": map[string]any{ + "clientCredentials": map[string]any{ + "tokenUrl": "https://auth.example.com/token", + }, + }, + } + schemes := map[string]any{ + "oauth2": securityScheme, + } + + e, err := api.New( + api.WithSwagger("Engine API", "Engine metadata", "2.0.0"), + api.WithSwaggerSecuritySchemes(schemes), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Mutate the original input after configuration. The builder snapshot should + // remain stable and keep the original token URL. + securityScheme["type"] = "mutated" + securityScheme["flows"].(map[string]any)["clientCredentials"].(map[string]any)["tokenUrl"] = "https://mutated.example.com/token" + + data, err := e.OpenAPISpecBuilder().Build(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + securitySchemes := spec["components"].(map[string]any)["securitySchemes"].(map[string]any) + oauth2, ok := securitySchemes["oauth2"].(map[string]any) + if !ok { + t.Fatal("expected oauth2 security scheme in generated spec") + } + if oauth2["type"] != "oauth2" { + t.Fatalf("expected cloned oauth2.type=oauth2, got %v", oauth2["type"]) + } + flows := oauth2["flows"].(map[string]any) + clientCredentials := flows["clientCredentials"].(map[string]any) + if clientCredentials["tokenUrl"] != "https://auth.example.com/token" { + t.Fatalf("expected original tokenUrl to be preserved, got %v", clientCredentials["tokenUrl"]) + } +} diff --git a/spec_registry.go b/spec_registry.go new file mode 100644 index 0000000..44902d7 --- /dev/null +++ b/spec_registry.go @@ -0,0 +1,154 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api + +import ( + "iter" + "sync" + + "slices" +) + +// specRegistry stores RouteGroups that should be included in CLI-generated +// OpenAPI documents. Packages can register their groups during init and the +// API CLI will pick them up when building specs or SDKs. +var specRegistry struct { + mu sync.RWMutex + groups []RouteGroup +} + +// RegisterSpecGroups adds route groups to the package-level spec registry. +// Nil groups are ignored. Registered groups are returned by RegisteredSpecGroups +// in the order they were added. +// +// Example: +// +// api.RegisterSpecGroups(api.NewToolBridge("/mcp")) +func RegisterSpecGroups(groups ...RouteGroup) { + RegisterSpecGroupsIter(slices.Values(groups)) +} + +// RegisterSpecGroupsIter adds route groups from an iterator to the package-level +// spec registry. +// +// Nil groups are ignored. Registered groups are returned by RegisteredSpecGroups +// in the order they were added. +// +// Example: +// +// api.RegisterSpecGroupsIter(api.RegisteredSpecGroupsIter()) +func RegisterSpecGroupsIter(groups iter.Seq[RouteGroup]) { + if groups == nil { + return + } + + specRegistry.mu.Lock() + defer specRegistry.mu.Unlock() + + for group := range groups { + if group == nil { + continue + } + if specRegistryContains(group) { + continue + } + specRegistry.groups = append(specRegistry.groups, group) + } +} + +// RegisteredSpecGroups returns a copy of the route groups registered for +// CLI-generated OpenAPI documents. +// +// Example: +// +// groups := api.RegisteredSpecGroups() +func RegisteredSpecGroups() []RouteGroup { + specRegistry.mu.RLock() + defer specRegistry.mu.RUnlock() + + out := make([]RouteGroup, len(specRegistry.groups)) + copy(out, specRegistry.groups) + return out +} + +// RegisteredSpecGroupsIter returns an iterator over the route groups registered +// for CLI-generated OpenAPI documents. +// +// The iterator snapshots the current registry contents so callers can range +// over it without holding the registry lock. +// +// Example: +// +// for g := range api.RegisteredSpecGroupsIter() { +// _ = g +// } +func RegisteredSpecGroupsIter() iter.Seq[RouteGroup] { + specRegistry.mu.RLock() + groups := slices.Clone(specRegistry.groups) + specRegistry.mu.RUnlock() + + return slices.Values(groups) +} + +// SpecGroupsIter returns the registered spec groups plus one optional extra +// group, deduplicated by group identity. +// +// The iterator snapshots the registry before yielding so callers can range +// over it without holding the registry lock. +// +// Example: +// +// for g := range api.SpecGroupsIter(api.NewToolBridge("/tools")) { +// _ = g +// } +func SpecGroupsIter(extra RouteGroup) iter.Seq[RouteGroup] { + return func(yield func(RouteGroup) bool) { + seen := map[string]struct{}{} + for group := range RegisteredSpecGroupsIter() { + key := specGroupKey(group) + seen[key] = struct{}{} + if !yield(group) { + return + } + } + if extra != nil { + if _, ok := seen[specGroupKey(extra)]; ok { + return + } + if !yield(extra) { + return + } + } + } +} + +// ResetSpecGroups clears the package-level spec registry. +// It is primarily intended for tests that need to isolate global state. +// +// Example: +// +// api.ResetSpecGroups() +func ResetSpecGroups() { + specRegistry.mu.Lock() + defer specRegistry.mu.Unlock() + + specRegistry.groups = nil +} + +func specRegistryContains(group RouteGroup) bool { + key := specGroupKey(group) + for _, existing := range specRegistry.groups { + if specGroupKey(existing) == key { + return true + } + } + return false +} + +func specGroupKey(group RouteGroup) string { + if group == nil { + return "" + } + + return group.Name() + "\x00" + group.BasePath() +} diff --git a/spec_registry_test.go b/spec_registry_test.go new file mode 100644 index 0000000..b144653 --- /dev/null +++ b/spec_registry_test.go @@ -0,0 +1,138 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api_test + +import ( + "iter" + "testing" + + "github.com/gin-gonic/gin" + + api "dappco.re/go/core/api" +) + +type specRegistryStubGroup struct { + name string + basePath string +} + +func (g *specRegistryStubGroup) Name() string { return g.name } +func (g *specRegistryStubGroup) BasePath() string { return g.basePath } +func (g *specRegistryStubGroup) RegisterRoutes(rg *gin.RouterGroup) {} + +func TestRegisterSpecGroups_Good_DeduplicatesByIdentity(t *testing.T) { + snapshot := api.RegisteredSpecGroups() + api.ResetSpecGroups() + t.Cleanup(func() { + api.ResetSpecGroups() + api.RegisterSpecGroups(snapshot...) + }) + + first := &specRegistryStubGroup{name: "alpha", basePath: "/alpha"} + second := &specRegistryStubGroup{name: "alpha", basePath: "/alpha"} + third := &specRegistryStubGroup{name: "beta", basePath: "/beta"} + + api.RegisterSpecGroups(nil, first, second, third, first) + + groups := api.RegisteredSpecGroups() + if len(groups) != 2 { + t.Fatalf("expected 2 unique groups, got %d", len(groups)) + } + + if groups[0].Name() != "alpha" || groups[0].BasePath() != "/alpha" { + t.Fatalf("expected first group to be alpha at /alpha, got %s at %s", groups[0].Name(), groups[0].BasePath()) + } + if groups[1].Name() != "beta" || groups[1].BasePath() != "/beta" { + t.Fatalf("expected second group to be beta at /beta, got %s at %s", groups[1].Name(), groups[1].BasePath()) + } +} + +func TestRegisterSpecGroups_Good_IteratorReturnsSnapshot(t *testing.T) { + snapshot := api.RegisteredSpecGroups() + api.ResetSpecGroups() + t.Cleanup(func() { + api.ResetSpecGroups() + api.RegisterSpecGroups(snapshot...) + }) + + first := &specRegistryStubGroup{name: "alpha", basePath: "/alpha"} + second := &specRegistryStubGroup{name: "beta", basePath: "/beta"} + + api.RegisterSpecGroups(first) + + iter := api.RegisteredSpecGroupsIter() + + api.RegisterSpecGroups(second) + + var groups []api.RouteGroup + for group := range iter { + groups = append(groups, group) + } + + if len(groups) != 1 { + t.Fatalf("expected iterator snapshot to contain 1 group, got %d", len(groups)) + } + if groups[0].Name() != "alpha" || groups[0].BasePath() != "/alpha" { + t.Fatalf("expected iterator snapshot to preserve alpha at /alpha, got %s at %s", groups[0].Name(), groups[0].BasePath()) + } +} + +func TestRegisterSpecGroupsIter_Good_DeduplicatesAndRegisters(t *testing.T) { + snapshot := api.RegisteredSpecGroups() + api.ResetSpecGroups() + t.Cleanup(func() { + api.ResetSpecGroups() + api.RegisterSpecGroups(snapshot...) + }) + + first := &specRegistryStubGroup{name: "alpha", basePath: "/alpha"} + second := &specRegistryStubGroup{name: "alpha", basePath: "/alpha"} + third := &specRegistryStubGroup{name: "gamma", basePath: "/gamma"} + + groups := iter.Seq[api.RouteGroup](func(yield func(api.RouteGroup) bool) { + for _, group := range []api.RouteGroup{first, second, nil, third, first} { + if !yield(group) { + return + } + } + }) + + api.RegisterSpecGroupsIter(groups) + + registered := api.RegisteredSpecGroups() + if len(registered) != 2 { + t.Fatalf("expected 2 unique groups, got %d", len(registered)) + } + if registered[0].Name() != "alpha" || registered[0].BasePath() != "/alpha" { + t.Fatalf("expected first group to be alpha at /alpha, got %s at %s", registered[0].Name(), registered[0].BasePath()) + } + if registered[1].Name() != "gamma" || registered[1].BasePath() != "/gamma" { + t.Fatalf("expected second group to be gamma at /gamma, got %s at %s", registered[1].Name(), registered[1].BasePath()) + } +} + +func TestSpecGroupsIter_Good_DeduplicatesExtraBridge(t *testing.T) { + snapshot := api.RegisteredSpecGroups() + api.ResetSpecGroups() + t.Cleanup(func() { + api.ResetSpecGroups() + api.RegisterSpecGroups(snapshot...) + }) + + first := &specRegistryStubGroup{name: "alpha", basePath: "/alpha"} + extra := &specRegistryStubGroup{name: "alpha", basePath: "/alpha"} + + api.RegisterSpecGroups(first) + + var groups []api.RouteGroup + for group := range api.SpecGroupsIter(extra) { + groups = append(groups, group) + } + + if len(groups) != 1 { + t.Fatalf("expected deduplicated iterator to return 1 group, got %d", len(groups)) + } + if groups[0].Name() != "alpha" || groups[0].BasePath() != "/alpha" { + t.Fatalf("expected alpha at /alpha, got %s at %s", groups[0].Name(), groups[0].BasePath()) + } +} diff --git a/src/php/src/Api/Concerns/HasApiResponses.php b/src/php/src/Api/Concerns/HasApiResponses.php index 3ab973b..e63de74 100644 --- a/src/php/src/Api/Concerns/HasApiResponses.php +++ b/src/php/src/Api/Concerns/HasApiResponses.php @@ -1,5 +1,7 @@ json(array_merge([ + 'success' => false, + 'error' => $errorCode, + 'message' => $message, + 'error_code' => $errorCode, + ], $meta), $status); + } + /** * Return a no workspace response. */ protected function noWorkspaceResponse(): JsonResponse { - return response()->json([ - 'error' => 'no_workspace', - 'message' => 'No workspace found. Please select a workspace first.', - ], 404); + return $this->errorResponse( + errorCode: 'no_workspace', + message: 'No workspace found. Please select a workspace first.', + status: 404, + ); } /** @@ -27,10 +47,14 @@ protected function noWorkspaceResponse(): JsonResponse */ protected function notFoundResponse(string $resource = 'Resource'): JsonResponse { - return response()->json([ - 'error' => 'not_found', - 'message' => "{$resource} not found.", - ], 404); + return $this->errorResponse( + errorCode: 'not_found', + message: "{$resource} not found.", + meta: [ + 'resource' => $resource, + ], + status: 404, + ); } /** @@ -38,12 +62,15 @@ protected function notFoundResponse(string $resource = 'Resource'): JsonResponse */ protected function limitReachedResponse(string $feature, ?string $message = null): JsonResponse { - return response()->json([ - 'error' => 'feature_limit_reached', - 'message' => $message ?? 'You have reached your limit for this feature.', - 'feature' => $feature, - 'upgrade_url' => route('hub.usage'), - ], 403); + return $this->errorResponse( + errorCode: 'feature_limit_reached', + message: $message ?? 'You have reached your limit for this feature.', + meta: [ + 'feature' => $feature, + 'upgrade_url' => route('hub.usage'), + ], + status: 403, + ); } /** @@ -51,10 +78,20 @@ protected function limitReachedResponse(string $feature, ?string $message = null */ protected function accessDeniedResponse(string $message = 'Access denied.'): JsonResponse { - return response()->json([ - 'error' => 'access_denied', - 'message' => $message, - ], 403); + return $this->forbiddenResponse($message, status: 403); + } + + /** + * Return a forbidden response. + */ + protected function forbiddenResponse(string $message, array $meta = [], int $status = 403): JsonResponse + { + return $this->errorResponse( + errorCode: 'forbidden', + message: $message, + meta: $meta, + status: $status, + ); } /** @@ -63,6 +100,7 @@ protected function accessDeniedResponse(string $message = 'Access denied.'): Jso protected function successResponse(string $message, array $data = []): JsonResponse { return response()->json(array_merge([ + 'success' => true, 'message' => $message, ], $data)); } @@ -73,6 +111,7 @@ protected function successResponse(string $message, array $data = []): JsonRespo protected function createdResponse(mixed $resource, string $message = 'Created successfully.'): JsonResponse { return response()->json([ + 'success' => true, 'message' => $message, 'data' => $resource, ], 201); @@ -81,13 +120,16 @@ protected function createdResponse(mixed $resource, string $message = 'Created s /** * Return a validation error response. */ - protected function validationErrorResponse(array $errors): JsonResponse + protected function validationErrorResponse(array $errors, int $status = 422): JsonResponse { - return response()->json([ - 'error' => 'validation_failed', - 'message' => 'The given data was invalid.', - 'errors' => $errors, - ], 422); + return $this->errorResponse( + errorCode: 'validation_failed', + message: 'The given data was invalid.', + meta: [ + 'errors' => $errors, + ], + status: $status, + ); } /** @@ -97,10 +139,11 @@ protected function validationErrorResponse(array $errors): JsonResponse */ protected function invalidStatusResponse(string $message): JsonResponse { - return response()->json([ - 'error' => 'invalid_status', - 'message' => $message, - ], 422); + return $this->errorResponse( + errorCode: 'invalid_status', + message: $message, + status: 422, + ); } /** @@ -110,15 +153,13 @@ protected function invalidStatusResponse(string $message): JsonResponse */ protected function providerErrorResponse(string $message, ?string $provider = null): JsonResponse { - $response = [ - 'error' => 'provider_error', - 'message' => $message, - ]; - - if ($provider !== null) { - $response['provider'] = $provider; - } - - return response()->json($response, 400); + return $this->errorResponse( + errorCode: 'provider_error', + message: $message, + meta: array_filter([ + 'provider' => $provider, + ]), + status: 400, + ); } } diff --git a/src/php/src/Api/Controllers/Api/EntitlementApiController.php b/src/php/src/Api/Controllers/Api/EntitlementApiController.php new file mode 100644 index 0000000..81fcfb4 --- /dev/null +++ b/src/php/src/Api/Controllers/Api/EntitlementApiController.php @@ -0,0 +1,101 @@ +resolveWorkspace($request); + + if (! $workspace instanceof Workspace) { + return $this->noWorkspaceResponse(); + } + + $apiKey = $request->attributes->get('api_key'); + $authType = $request->attributes->get('auth_type', 'session'); + $rateLimitProfile = $this->resolveRateLimitProfile($authType); + $activeApiKeys = ApiKey::query() + ->forWorkspace($workspace->id) + ->active() + ->count(); + + $usage = $this->usageService->getWorkspaceSummary($workspace->id); + + return response()->json([ + 'workspace_id' => $workspace->id, + 'workspace' => [ + 'id' => $workspace->id, + 'name' => $workspace->name ?? null, + ], + 'authentication' => [ + 'type' => $authType, + 'scopes' => $apiKey instanceof ApiKey ? $apiKey->scopes : null, + ], + 'limits' => [ + 'rate_limit' => $rateLimitProfile, + 'api_keys' => [ + 'active' => $activeApiKeys, + 'maximum' => (int) config('api.keys.max_per_workspace', 10), + 'remaining' => max(0, (int) config('api.keys.max_per_workspace', 10) - $activeApiKeys), + ], + 'webhooks' => [ + 'maximum' => (int) config('api.webhooks.max_per_workspace', 5), + ], + ], + 'usage' => $usage, + 'features' => [ + 'pixel' => true, + 'mcp' => true, + 'webhooks' => true, + 'usage_alerts' => (bool) config('api.alerts.enabled', true), + ], + ]); + } + + /** + * Resolve the rate limit profile for the current auth context. + */ + protected function resolveRateLimitProfile(string $authType): array + { + $rateLimits = (array) config('api.rate_limits', []); + $key = $authType === 'session' ? 'default' : 'authenticated'; + $profile = (array) ($rateLimits[$key] ?? []); + + return [ + 'name' => $key, + 'limit' => (int) ($profile['limit'] ?? 0), + 'window' => (int) ($profile['window'] ?? 60), + 'burst' => (float) ($profile['burst'] ?? 1.0), + ]; + } +} diff --git a/src/php/src/Api/Controllers/Api/SeoReportController.php b/src/php/src/Api/Controllers/Api/SeoReportController.php new file mode 100644 index 0000000..35ff9d0 --- /dev/null +++ b/src/php/src/Api/Controllers/Api/SeoReportController.php @@ -0,0 +1,65 @@ +validate([ + 'url' => ['required', 'url'], + ]); + + try { + $report = $this->seoReportService->analyse($validated['url']); + } catch (RuntimeException) { + return $this->errorResponse( + errorCode: 'seo_unavailable', + message: 'Unable to fetch the requested URL.', + meta: [ + 'provider' => 'seo', + ], + status: 502, + ); + } + + return response()->json([ + 'data' => $report, + ]); + } +} diff --git a/src/php/src/Api/Controllers/Api/UnifiedPixelController.php b/src/php/src/Api/Controllers/Api/UnifiedPixelController.php new file mode 100644 index 0000000..3b16ffe --- /dev/null +++ b/src/php/src/Api/Controllers/Api/UnifiedPixelController.php @@ -0,0 +1,64 @@ + transparent GIF + * POST /api/pixel/abc12345 -> 204 No Content + */ + #[ApiResponse( + 200, + null, + 'Transparent 1x1 GIF pixel response', + contentType: 'image/gif', + schema: [ + 'type' => 'string', + 'format' => 'binary', + ], + )] + #[ApiResponse(204, null, 'Accepted without a response body')] + #[RateLimit(limit: 10000, window: 60)] + public function track(Request $request, string $pixelKey): Response + { + if ($request->isMethod('post')) { + return response()->noContent() + ->header('Cache-Control', 'no-store, no-cache, must-revalidate, max-age=0') + ->header('Pragma', 'no-cache') + ->header('Expires', '0'); + } + + $pixel = base64_decode(self::TRANSPARENT_GIF); + + return response($pixel, 200) + ->header('Content-Type', 'image/gif') + ->header('Cache-Control', 'no-store, no-cache, must-revalidate, max-age=0') + ->header('Pragma', 'no-cache') + ->header('Expires', '0') + ->header('Content-Length', (string) strlen($pixel)); + } +} diff --git a/src/php/src/Api/Controllers/Api/WebhookSecretController.php b/src/php/src/Api/Controllers/Api/WebhookSecretController.php index 5dee8c1..e5284be 100644 --- a/src/php/src/Api/Controllers/Api/WebhookSecretController.php +++ b/src/php/src/Api/Controllers/Api/WebhookSecretController.php @@ -4,6 +4,7 @@ namespace Core\Api\Controllers\Api; +use Core\Api\Concerns\HasApiResponses; use Illuminate\Http\JsonResponse; use Illuminate\Http\Request; use Illuminate\Routing\Controller; @@ -16,6 +17,8 @@ */ class WebhookSecretController extends Controller { + use HasApiResponses; + public function __construct( protected WebhookSecretRotationService $rotationService ) {} @@ -28,7 +31,7 @@ public function rotateSocialSecret(Request $request, string $uuid): JsonResponse $workspace = $request->user()?->defaultHostWorkspace(); if (! $workspace) { - return response()->json(['error' => 'Workspace not found'], 404); + return $this->noWorkspaceResponse(); } $webhook = Webhook::where('workspace_id', $workspace->id) @@ -36,7 +39,7 @@ public function rotateSocialSecret(Request $request, string $uuid): JsonResponse ->first(); if (! $webhook) { - return response()->json(['error' => 'Webhook not found'], 404); + return $this->notFoundResponse('Webhook'); } $validated = $request->validate([ @@ -66,7 +69,7 @@ public function rotateContentSecret(Request $request, string $uuid): JsonRespons $workspace = $request->user()?->defaultHostWorkspace(); if (! $workspace) { - return response()->json(['error' => 'Workspace not found'], 404); + return $this->noWorkspaceResponse(); } $endpoint = ContentWebhookEndpoint::where('workspace_id', $workspace->id) @@ -74,7 +77,7 @@ public function rotateContentSecret(Request $request, string $uuid): JsonRespons ->first(); if (! $endpoint) { - return response()->json(['error' => 'Webhook endpoint not found'], 404); + return $this->notFoundResponse('Webhook endpoint'); } $validated = $request->validate([ @@ -104,7 +107,7 @@ public function socialSecretStatus(Request $request, string $uuid): JsonResponse $workspace = $request->user()?->defaultHostWorkspace(); if (! $workspace) { - return response()->json(['error' => 'Workspace not found'], 404); + return $this->noWorkspaceResponse(); } $webhook = Webhook::where('workspace_id', $workspace->id) @@ -112,7 +115,7 @@ public function socialSecretStatus(Request $request, string $uuid): JsonResponse ->first(); if (! $webhook) { - return response()->json(['error' => 'Webhook not found'], 404); + return $this->notFoundResponse('Webhook'); } return response()->json([ @@ -128,7 +131,7 @@ public function contentSecretStatus(Request $request, string $uuid): JsonRespons $workspace = $request->user()?->defaultHostWorkspace(); if (! $workspace) { - return response()->json(['error' => 'Workspace not found'], 404); + return $this->noWorkspaceResponse(); } $endpoint = ContentWebhookEndpoint::where('workspace_id', $workspace->id) @@ -136,7 +139,7 @@ public function contentSecretStatus(Request $request, string $uuid): JsonRespons ->first(); if (! $endpoint) { - return response()->json(['error' => 'Webhook endpoint not found'], 404); + return $this->notFoundResponse('Webhook endpoint'); } return response()->json([ @@ -152,7 +155,7 @@ public function invalidateSocialPreviousSecret(Request $request, string $uuid): $workspace = $request->user()?->defaultHostWorkspace(); if (! $workspace) { - return response()->json(['error' => 'Workspace not found'], 404); + return $this->noWorkspaceResponse(); } $webhook = Webhook::where('workspace_id', $workspace->id) @@ -160,7 +163,7 @@ public function invalidateSocialPreviousSecret(Request $request, string $uuid): ->first(); if (! $webhook) { - return response()->json(['error' => 'Webhook not found'], 404); + return $this->notFoundResponse('Webhook'); } $this->rotationService->invalidatePreviousSecret($webhook); @@ -179,7 +182,7 @@ public function invalidateContentPreviousSecret(Request $request, string $uuid): $workspace = $request->user()?->defaultHostWorkspace(); if (! $workspace) { - return response()->json(['error' => 'Workspace not found'], 404); + return $this->noWorkspaceResponse(); } $endpoint = ContentWebhookEndpoint::where('workspace_id', $workspace->id) @@ -187,7 +190,7 @@ public function invalidateContentPreviousSecret(Request $request, string $uuid): ->first(); if (! $endpoint) { - return response()->json(['error' => 'Webhook endpoint not found'], 404); + return $this->notFoundResponse('Webhook endpoint'); } $this->rotationService->invalidatePreviousSecret($endpoint); @@ -206,7 +209,7 @@ public function updateSocialGracePeriod(Request $request, string $uuid): JsonRes $workspace = $request->user()?->defaultHostWorkspace(); if (! $workspace) { - return response()->json(['error' => 'Workspace not found'], 404); + return $this->noWorkspaceResponse(); } $webhook = Webhook::where('workspace_id', $workspace->id) @@ -214,7 +217,7 @@ public function updateSocialGracePeriod(Request $request, string $uuid): JsonRes ->first(); if (! $webhook) { - return response()->json(['error' => 'Webhook not found'], 404); + return $this->notFoundResponse('Webhook'); } $validated = $request->validate([ @@ -240,7 +243,7 @@ public function updateContentGracePeriod(Request $request, string $uuid): JsonRe $workspace = $request->user()?->defaultHostWorkspace(); if (! $workspace) { - return response()->json(['error' => 'Workspace not found'], 404); + return $this->noWorkspaceResponse(); } $endpoint = ContentWebhookEndpoint::where('workspace_id', $workspace->id) @@ -248,7 +251,7 @@ public function updateContentGracePeriod(Request $request, string $uuid): JsonRe ->first(); if (! $endpoint) { - return response()->json(['error' => 'Webhook endpoint not found'], 404); + return $this->notFoundResponse('Webhook endpoint'); } $validated = $request->validate([ diff --git a/src/php/src/Api/Controllers/Api/WebhookTemplateController.php b/src/php/src/Api/Controllers/Api/WebhookTemplateController.php index d9f20eb..9078fed 100644 --- a/src/php/src/Api/Controllers/Api/WebhookTemplateController.php +++ b/src/php/src/Api/Controllers/Api/WebhookTemplateController.php @@ -4,6 +4,7 @@ namespace Core\Api\Controllers\Api; +use Core\Api\Concerns\HasApiResponses; use Illuminate\Http\JsonResponse; use Illuminate\Http\Request; use Illuminate\Routing\Controller; @@ -17,6 +18,8 @@ */ class WebhookTemplateController extends Controller { + use HasApiResponses; + public function __construct( protected WebhookTemplateService $templateService ) {} @@ -29,7 +32,7 @@ public function index(Request $request): JsonResponse $workspace = $request->user()?->defaultHostWorkspace(); if (! $workspace) { - return response()->json(['error' => 'Workspace not found'], 404); + return $this->noWorkspaceResponse(); } $query = WebhookPayloadTemplate::where('workspace_id', $workspace->id) @@ -61,7 +64,7 @@ public function show(Request $request, string $uuid): JsonResponse $workspace = $request->user()?->defaultHostWorkspace(); if (! $workspace) { - return response()->json(['error' => 'Workspace not found'], 404); + return $this->noWorkspaceResponse(); } $template = WebhookPayloadTemplate::where('workspace_id', $workspace->id) @@ -69,7 +72,7 @@ public function show(Request $request, string $uuid): JsonResponse ->first(); if (! $template) { - return response()->json(['error' => 'Template not found'], 404); + return $this->notFoundResponse('Template'); } return response()->json([ @@ -85,7 +88,7 @@ public function store(Request $request): JsonResponse $workspace = $request->user()?->defaultHostWorkspace(); if (! $workspace) { - return response()->json(['error' => 'Workspace not found'], 404); + return $this->noWorkspaceResponse(); } $validated = $request->validate([ @@ -102,10 +105,9 @@ public function store(Request $request): JsonResponse $validation = $this->templateService->validateTemplate($validated['template'], $format); if (! $validation['valid']) { - return response()->json([ - 'error' => 'Invalid template', - 'errors' => $validation['errors'], - ], 422); + return $this->validationErrorResponse([ + 'template' => $validation['errors'], + ]); } $template = WebhookPayloadTemplate::create([ @@ -133,7 +135,7 @@ public function update(Request $request, string $uuid): JsonResponse $workspace = $request->user()?->defaultHostWorkspace(); if (! $workspace) { - return response()->json(['error' => 'Workspace not found'], 404); + return $this->noWorkspaceResponse(); } $template = WebhookPayloadTemplate::where('workspace_id', $workspace->id) @@ -141,7 +143,7 @@ public function update(Request $request, string $uuid): JsonResponse ->first(); if (! $template) { - return response()->json(['error' => 'Template not found'], 404); + return $this->notFoundResponse('Template'); } $validated = $request->validate([ @@ -159,10 +161,9 @@ public function update(Request $request, string $uuid): JsonResponse $validation = $this->templateService->validateTemplate($validated['template'], $format); if (! $validation['valid']) { - return response()->json([ - 'error' => 'Invalid template', - 'errors' => $validation['errors'], - ], 422); + return $this->validationErrorResponse([ + 'template' => $validation['errors'], + ]); } } @@ -186,7 +187,7 @@ public function destroy(Request $request, string $uuid): JsonResponse $workspace = $request->user()?->defaultHostWorkspace(); if (! $workspace) { - return response()->json(['error' => 'Workspace not found'], 404); + return $this->noWorkspaceResponse(); } $template = WebhookPayloadTemplate::where('workspace_id', $workspace->id) @@ -194,12 +195,12 @@ public function destroy(Request $request, string $uuid): JsonResponse ->first(); if (! $template) { - return response()->json(['error' => 'Template not found'], 404); + return $this->notFoundResponse('Template'); } // Don't allow deleting builtin templates if ($template->isBuiltin()) { - return response()->json(['error' => 'Built-in templates cannot be deleted'], 403); + return $this->forbiddenResponse('Built-in templates cannot be deleted'); } $template->delete(); @@ -255,7 +256,7 @@ public function duplicate(Request $request, string $uuid): JsonResponse $workspace = $request->user()?->defaultHostWorkspace(); if (! $workspace) { - return response()->json(['error' => 'Workspace not found'], 404); + return $this->noWorkspaceResponse(); } $template = WebhookPayloadTemplate::where('workspace_id', $workspace->id) @@ -263,7 +264,7 @@ public function duplicate(Request $request, string $uuid): JsonResponse ->first(); if (! $template) { - return response()->json(['error' => 'Template not found'], 404); + return $this->notFoundResponse('Template'); } $newName = $request->input('name', $template->name.' (copy)'); @@ -282,7 +283,7 @@ public function setDefault(Request $request, string $uuid): JsonResponse $workspace = $request->user()?->defaultHostWorkspace(); if (! $workspace) { - return response()->json(['error' => 'Workspace not found'], 404); + return $this->noWorkspaceResponse(); } $template = WebhookPayloadTemplate::where('workspace_id', $workspace->id) @@ -290,7 +291,7 @@ public function setDefault(Request $request, string $uuid): JsonResponse ->first(); if (! $template) { - return response()->json(['error' => 'Template not found'], 404); + return $this->notFoundResponse('Template'); } $template->setAsDefault(); diff --git a/src/php/src/Api/Controllers/McpApiController.php b/src/php/src/Api/Controllers/McpApiController.php index 828e85b..574e8cc 100644 --- a/src/php/src/Api/Controllers/McpApiController.php +++ b/src/php/src/Api/Controllers/McpApiController.php @@ -5,6 +5,8 @@ namespace Core\Api\Controllers; use Core\Front\Controller; +use Core\Api\Concerns\HasApiResponses; +use Core\Api\Documentation\Attributes\ApiParameter; use Core\Api\Models\ApiKey; use Core\Mod\Mcp\Models\McpApiRequest; use Core\Mod\Mcp\Models\McpToolCall; @@ -23,6 +25,8 @@ */ class McpApiController extends Controller { + use HasApiResponses; + /** * List all available MCP servers. * @@ -47,15 +51,48 @@ public function servers(Request $request): JsonResponse * Get server details with tools and resources. * * GET /api/v1/mcp/servers/{id} + * + * Query params: + * - include_versions: bool - include version info for each tool + * - include_content: bool - include resource content when the definition already contains it */ + #[ApiParameter( + name: 'include_versions', + in: 'query', + type: 'boolean', + description: 'Include version information for each tool', + required: false, + example: false, + default: false + )] + #[ApiParameter( + name: 'include_content', + in: 'query', + type: 'boolean', + description: 'Include resource content when the definition already contains it', + required: false, + example: false, + default: false + )] public function server(Request $request, string $id): JsonResponse { $server = $this->loadServerFull($id); if (! $server) { - return response()->json(['error' => 'Server not found'], 404); + return $this->notFoundResponse('Server'); + } + + if ($request->boolean('include_versions', false)) { + $server['tools'] = $this->enrichToolsWithVersioning($id, $server['tools'] ?? []); } + if ($request->boolean('include_content', false)) { + $server['resources'] = $this->enrichResourcesWithContent($server['resources'] ?? []); + } + + $server['tool_count'] = count($server['tools'] ?? []); + $server['resource_count'] = count($server['resources'] ?? []); + return response()->json($server); } @@ -67,12 +104,21 @@ public function server(Request $request, string $id): JsonResponse * Query params: * - include_versions: bool - include version info for each tool */ + #[ApiParameter( + name: 'include_versions', + in: 'query', + type: 'boolean', + description: 'Include version information for each tool', + required: false, + example: false, + default: false + )] public function tools(Request $request, string $id): JsonResponse { $server = $this->loadServerFull($id); if (! $server) { - return response()->json(['error' => 'Server not found'], 404); + return $this->notFoundResponse('Server'); } $tools = $server['tools'] ?? []; @@ -107,6 +153,116 @@ public function tools(Request $request, string $id): JsonResponse ]); } + /** + * List resources for a specific server. + * + * GET /api/v1/mcp/servers/{id}/resources + * + * Query params: + * - include_content: bool - include resource content when the definition already contains it + */ + #[ApiParameter( + name: 'include_content', + in: 'query', + type: 'boolean', + description: 'Include resource content when the definition already contains it', + required: false, + example: false, + default: false + )] + public function resources(Request $request, string $id): JsonResponse + { + $server = $this->loadServerFull($id); + + if (! $server) { + return $this->notFoundResponse('Server'); + } + + $includeContent = $request->boolean('include_content', false); + + $resources = collect($server['resources'] ?? []) + ->filter(fn ($resource) => is_array($resource)) + ->map(function (array $resource) use ($includeContent) { + $payload = array_filter([ + 'uri' => $resource['uri'] ?? null, + 'path' => $resource['path'] ?? null, + 'name' => $resource['name'] ?? null, + 'description' => $resource['description'] ?? null, + 'mime_type' => $resource['mime_type'] ?? ($resource['mimeType'] ?? null), + ], static fn ($value) => $value !== null); + + if ($includeContent && $this->resourceDefinitionHasContent($resource)) { + $payload['content'] = $this->normaliseResourceContent($resource); + } + + return $payload; + }) + ->values(); + + return response()->json([ + 'server' => $id, + 'resources' => $resources, + 'count' => $resources->count(), + ]); + } + + /** + * Enrich a tool collection with version metadata. + * + * @param array> $tools + * @return array> + */ + protected function enrichToolsWithVersioning(string $serverId, array $tools): array + { + $versionService = app(ToolVersionService::class); + + return collect($tools)->map(function (array $tool) use ($serverId, $versionService) { + $toolName = $tool['name'] ?? ''; + $latestVersion = $versionService->getLatestVersion($serverId, $toolName); + + $tool['versioning'] = [ + 'latest_version' => $latestVersion?->version ?? ToolVersionService::DEFAULT_VERSION, + 'is_versioned' => $latestVersion !== null, + 'deprecated' => $latestVersion?->is_deprecated ?? false, + ]; + + if ($latestVersion?->input_schema) { + $tool['inputSchema'] = $latestVersion->input_schema; + } + + return $tool; + })->all(); + } + + /** + * Enrich a resource collection with inline content when available. + * + * @param array> $resources + * @return array> + */ + protected function enrichResourcesWithContent(array $resources): array + { + return collect($resources) + ->filter(fn ($resource) => is_array($resource)) + ->map(function (array $resource) { + $payload = array_filter([ + 'uri' => $resource['uri'] ?? null, + 'path' => $resource['path'] ?? null, + 'name' => $resource['name'] ?? null, + 'description' => $resource['description'] ?? null, + 'mime_type' => $resource['mime_type'] ?? ($resource['mimeType'] ?? null), + ], static fn ($value) => $value !== null); + + if ($this->resourceDefinitionHasContent($resource)) { + $payload['content'] = $this->normaliseResourceContent($resource); + } + + return $payload; + }) + ->values() + ->all(); + } + /** * Execute a tool on an MCP server. * @@ -129,13 +285,13 @@ public function callTool(Request $request): JsonResponse $server = $this->loadServerFull($validated['server']); if (! $server) { - return response()->json(['error' => 'Server not found'], 404); + return $this->notFoundResponse('Server'); } // Verify tool exists in server definition $toolDef = collect($server['tools'] ?? [])->firstWhere('name', $validated['tool']); if (! $toolDef) { - return response()->json(['error' => 'Tool not found'], 404); + return $this->notFoundResponse('Tool'); } // Version resolution @@ -153,16 +309,18 @@ public function callTool(Request $request): JsonResponse // Sunset versions return 410 Gone $status = ($error['code'] ?? '') === 'TOOL_VERSION_SUNSET' ? 410 : 400; - return response()->json([ - 'success' => false, - 'error' => $error['message'] ?? 'Version error', - 'error_code' => $error['code'] ?? 'VERSION_ERROR', - 'server' => $validated['server'], - 'tool' => $validated['tool'], - 'requested_version' => $validated['version'] ?? null, - 'latest_version' => $error['latest_version'] ?? null, - 'migration_notes' => $error['migration_notes'] ?? null, - ], $status); + return $this->errorResponse( + errorCode: $error['code'] ?? 'VERSION_ERROR', + message: $error['message'] ?? 'Version error', + meta: [ + 'server' => $validated['server'], + 'tool' => $validated['tool'], + 'requested_version' => $validated['version'] ?? null, + 'latest_version' => $error['latest_version'] ?? null, + 'migration_notes' => $error['migration_notes'] ?? null, + ], + status: $status, + ); } /** @var McpToolVersion|null $toolVersion */ @@ -178,15 +336,17 @@ public function callTool(Request $request): JsonResponse ); if (! empty($validationErrors)) { - return response()->json([ - 'success' => false, - 'error' => 'Validation failed', - 'error_code' => 'VALIDATION_ERROR', - 'validation_errors' => $validationErrors, - 'server' => $validated['server'], - 'tool' => $validated['tool'], - 'version' => $toolVersion?->version ?? 'unversioned', - ], 422); + return $this->errorResponse( + errorCode: 'VALIDATION_ERROR', + message: 'Validation failed', + meta: [ + 'validation_errors' => $validationErrors, + 'server' => $validated['server'], + 'tool' => $validated['tool'], + 'version' => $toolVersion?->version ?? 'unversioned', + ], + status: 422, + ); } } @@ -201,7 +361,8 @@ public function callTool(Request $request): JsonResponse $result = $this->executeToolViaArtisan( $validated['server'], $validated['tool'], - $validated['arguments'] ?? [] + $validated['arguments'] ?? [], + $toolVersion?->version ); $durationMs = (int) ((microtime(true) - $startTime) * 1000); @@ -262,7 +423,16 @@ public function callTool(Request $request): JsonResponse // Log full request for debugging/replay $this->logApiRequest($request, $validated, 500, $response, $durationMs, $apiKey, $e->getMessage()); - return response()->json($response, 500); + return $this->errorResponse( + errorCode: 'tool_execution_error', + message: $e->getMessage(), + meta: array_filter([ + 'server' => $validated['server'], + 'tool' => $validated['tool'], + 'version' => $toolVersion?->version ?? ToolVersionService::DEFAULT_VERSION, + ]), + status: 500, + ); } } @@ -343,13 +513,13 @@ public function toolVersions(Request $request, string $server, string $tool): Js { $serverConfig = $this->loadServerFull($server); if (! $serverConfig) { - return response()->json(['error' => 'Server not found'], 404); + return $this->notFoundResponse('Server'); } // Verify tool exists in server definition $toolDef = collect($serverConfig['tools'] ?? [])->firstWhere('name', $tool); if (! $toolDef) { - return response()->json(['error' => 'Tool not found'], 404); + return $this->notFoundResponse('Tool'); } $versionService = app(ToolVersionService::class); @@ -374,7 +544,7 @@ public function toolVersion(Request $request, string $server, string $tool, stri $toolVersion = $versionService->getToolAtVersion($server, $tool, $version); if (! $toolVersion) { - return response()->json(['error' => 'Version not found'], 404); + return $this->notFoundResponse('Version'); } $response = response()->json($toolVersion->toApiArray()); @@ -397,9 +567,13 @@ public function toolVersion(Request $request, string $server, string $tool, stri */ public function resource(Request $request, string $uri): JsonResponse { + $uri = rawurldecode($uri); + // Parse URI format: server://resource/path if (! preg_match('/^([a-z0-9-]+):\/\/(.+)$/', $uri, $matches)) { - return response()->json(['error' => 'Invalid resource URI format'], 400); + return $this->validationErrorResponse([ + 'uri' => ['Invalid resource URI format. Expected pattern server://resource/path'], + ], 400); } $serverId = $matches[1]; @@ -407,55 +581,139 @@ public function resource(Request $request, string $uri): JsonResponse $server = $this->loadServerFull($serverId); if (! $server) { - return response()->json(['error' => 'Server not found'], 404); + return $this->notFoundResponse('Server'); + } + + $resourceDef = $this->findResourceDefinition($server, $uri, $resourcePath); + if ($resourceDef !== null && $this->resourceDefinitionHasContent($resourceDef)) { + return response()->json([ + 'uri' => $uri, + 'server' => $serverId, + 'resource' => $resourcePath, + 'content' => $this->normaliseResourceContent($resourceDef), + ]); } try { $result = $this->readResourceViaArtisan($serverId, $resourcePath); + if ($result === null) { + return $this->notFoundResponse('Resource'); + } + + if (is_array($result) && array_key_exists('content', $result)) { + $content = $result['content']; + } elseif (is_array($result) && array_key_exists('contents', $result)) { + $content = $result['contents']; + } else { + $content = $result; + } return response()->json([ 'uri' => $uri, - 'content' => $result, + 'server' => $serverId, + 'resource' => $resourcePath, + 'content' => $content, ]); } catch (\Throwable $e) { - return response()->json([ - 'error' => $e->getMessage(), - 'uri' => $uri, - ], 500); + return $this->errorResponse( + errorCode: 'resource_read_error', + message: $e->getMessage(), + meta: [ + 'uri' => $uri, + ], + status: 500, + ); } } /** * Execute tool via artisan MCP server command. */ - protected function executeToolViaArtisan(string $server, string $tool, array $arguments): mixed + protected function executeToolViaArtisan(string $server, string $tool, array $arguments, ?string $version = null): mixed { - $commandMap = [ - 'hosthub-agent' => 'mcp:agent-server', - 'socialhost' => 'mcp:socialhost-server', - 'biohost' => 'mcp:biohost-server', - 'commerce' => 'mcp:commerce-server', - 'supporthost' => 'mcp:support-server', - 'upstream' => 'mcp:upstream-server', + $command = $this->resolveMcpServerCommand($server); + if (! $command) { + throw new \RuntimeException("Unknown server: {$server}"); + } + + $mcpRequest = $this->buildToolCallRequest($tool, $arguments, $version); + + // Execute via process + $process = proc_open( + ['php', 'artisan', $command], + [ + 0 => ['pipe', 'r'], + 1 => ['pipe', 'w'], + 2 => ['pipe', 'w'], + ], + $pipes, + base_path() + ); + + if (! is_resource($process)) { + throw new \RuntimeException('Failed to start MCP server process'); + } + + fwrite($pipes[0], json_encode($mcpRequest)."\n"); + fclose($pipes[0]); + + $output = stream_get_contents($pipes[1]); + fclose($pipes[1]); + fclose($pipes[2]); + + proc_close($process); + + $response = json_decode($output, true); + + if (isset($response['error'])) { + throw new \RuntimeException($response['error']['message'] ?? 'Tool execution failed'); + } + + return $response['result'] ?? null; + } + + /** + * Build the JSON-RPC payload for an MCP tool call. + */ + protected function buildToolCallRequest(string $tool, array $arguments, ?string $version = null): array + { + $params = [ + 'name' => $tool, + 'arguments' => $arguments, ]; - $command = $commandMap[$server] ?? null; + if ($version !== null && $version !== '') { + $params['version'] = $version; + } + + return [ + 'jsonrpc' => '2.0', + 'id' => uniqid(), + 'method' => 'tools/call', + 'params' => $params, + ]; + } + + /** + * Read resource via artisan MCP server command. + */ + protected function readResourceViaArtisan(string $server, string $path): mixed + { + $command = $this->resolveMcpServerCommand($server); if (! $command) { throw new \RuntimeException("Unknown server: {$server}"); } - // Build MCP request $mcpRequest = [ 'jsonrpc' => '2.0', 'id' => uniqid(), - 'method' => 'tools/call', + 'method' => 'resources/read', 'params' => [ - 'name' => $tool, - 'arguments' => $arguments, + 'uri' => "{$server}://{$path}", + 'path' => $path, ], ]; - // Execute via process $process = proc_open( ['php', 'artisan', $command], [ @@ -481,22 +739,90 @@ protected function executeToolViaArtisan(string $server, string $tool, array $ar proc_close($process); $response = json_decode($output, true); + if (! is_array($response)) { + throw new \RuntimeException('Invalid MCP resource response'); + } if (isset($response['error'])) { - throw new \RuntimeException($response['error']['message'] ?? 'Tool execution failed'); + throw new \RuntimeException($response['error']['message'] ?? 'Resource read failed'); } return $response['result'] ?? null; } /** - * Read resource via artisan MCP server command. + * Resolve the artisan command used for a given MCP server. */ - protected function readResourceViaArtisan(string $server, string $path): mixed + protected function resolveMcpServerCommand(string $server): ?string + { + $commandMap = [ + 'hosthub-agent' => 'mcp:agent-server', + 'socialhost' => 'mcp:socialhost-server', + 'biohost' => 'mcp:biohost-server', + 'commerce' => 'mcp:commerce-server', + 'supporthost' => 'mcp:support-server', + 'upstream' => 'mcp:upstream-server', + ]; + + return $commandMap[$server] ?? null; + } + + /** + * Find a resource definition within the loaded server config. + */ + protected function findResourceDefinition(array $server, string $uri, string $path): mixed + { + foreach ($server['resources'] ?? [] as $resource) { + if (! is_array($resource)) { + continue; + } + + $resourceUri = $resource['uri'] ?? null; + $resourcePath = $resource['path'] ?? null; + $resourceName = $resource['name'] ?? null; + + if ($resourceUri === $uri || $resourcePath === $path || $resourceName === basename($path)) { + return $resource; + } + } + + return null; + } + + /** + * Normalise a resource definition into a response payload. + */ + protected function normaliseResourceContent(mixed $resource): mixed { - // Similar to executeToolViaArtisan but with resources/read method - // Simplified for now - can expand later - return ['path' => $path, 'content' => 'Resource reading not yet implemented']; + if (! is_array($resource)) { + return $resource; + } + + foreach (['content', 'contents', 'body', 'text', 'value'] as $field) { + if (array_key_exists($field, $resource)) { + return $resource[$field]; + } + } + + return $resource; + } + + /** + * Determine whether a resource definition already carries readable content. + */ + protected function resourceDefinitionHasContent(mixed $resource): bool + { + if (! is_array($resource)) { + return true; + } + + foreach (['content', 'contents', 'body', 'text', 'value'] as $field) { + if (array_key_exists($field, $resource)) { + return true; + } + } + + return false; } /** diff --git a/src/php/src/Api/Documentation/Attributes/ApiResponse.php b/src/php/src/Api/Documentation/Attributes/ApiResponse.php index 222350b..6958c69 100644 --- a/src/php/src/Api/Documentation/Attributes/ApiResponse.php +++ b/src/php/src/Api/Documentation/Attributes/ApiResponse.php @@ -27,6 +27,19 @@ * { * return UserResource::collection(User::paginate()); * } + * + * // For non-JSON or binary responses + * #[ApiResponse( + * 200, + * null, + * 'Transparent tracking pixel', + * contentType: 'image/gif', + * schema: ['type' => 'string', 'format' => 'binary'] + * )] + * public function pixel() + * { + * return response($gif, 200)->header('Content-Type', 'image/gif'); + * } */ #[Attribute(Attribute::TARGET_METHOD | Attribute::IS_REPEATABLE)] readonly class ApiResponse @@ -37,6 +50,8 @@ * @param string|null $description Description of the response * @param bool $paginated Whether this is a paginated collection response * @param array $headers Additional response headers to document + * @param string|null $contentType Explicit response media type for non-JSON responses + * @param array|null $schema Explicit response schema when the body is not inferred from a resource */ public function __construct( public int $status, @@ -44,6 +59,8 @@ public function __construct( public ?string $description = null, public bool $paginated = false, public array $headers = [], + public ?string $contentType = null, + public ?array $schema = null, ) {} /** @@ -64,10 +81,11 @@ public function getDescription(): string 302 => 'Found (redirect)', 304 => 'Not modified', 400 => 'Bad request', - 401 => 'Unauthorized', + 401 => 'Unauthorised', 403 => 'Forbidden', 404 => 'Not found', 405 => 'Method not allowed', + 410 => 'Gone', 409 => 'Conflict', 422 => 'Validation error', 429 => 'Too many requests', diff --git a/src/php/src/Api/Documentation/DocumentationController.php b/src/php/src/Api/Documentation/DocumentationController.php index cca3d08..8c7a3ad 100644 --- a/src/php/src/Api/Documentation/DocumentationController.php +++ b/src/php/src/Api/Documentation/DocumentationController.php @@ -34,6 +34,7 @@ public function index(Request $request): View return match ($defaultUi) { 'swagger' => $this->swagger($request), 'redoc' => $this->redoc($request), + 'stoplight' => $this->stoplight($request), default => $this->scalar($request), }; } @@ -74,6 +75,19 @@ public function redoc(Request $request): View ]); } + /** + * Show Stoplight Elements. + */ + public function stoplight(Request $request): View + { + $config = config('api-docs.ui.stoplight', []); + + return view('api-docs::stoplight', [ + 'specUrl' => route('api.docs.openapi.json'), + 'config' => $config, + ]); + } + /** * Get OpenAPI specification as JSON. */ diff --git a/src/php/src/Api/Documentation/Extensions/ApiKeyAuthExtension.php b/src/php/src/Api/Documentation/Extensions/ApiKeyAuthExtension.php index baf0259..0344700 100644 --- a/src/php/src/Api/Documentation/Extensions/ApiKeyAuthExtension.php +++ b/src/php/src/Api/Documentation/Extensions/ApiKeyAuthExtension.php @@ -53,7 +53,7 @@ public function extend(array $spec, array $config): array 'properties' => [ 'message' => [ 'type' => 'string', - 'example' => 'This action is unauthorized.', + 'example' => 'This action is unauthorised.', ], ], ]; diff --git a/src/php/src/Api/Documentation/Extensions/SunsetExtension.php b/src/php/src/Api/Documentation/Extensions/SunsetExtension.php new file mode 100644 index 0000000..d20709e --- /dev/null +++ b/src/php/src/Api/Documentation/Extensions/SunsetExtension.php @@ -0,0 +1,147 @@ + 'Indicates that the endpoint is deprecated.', + 'schema' => [ + 'type' => 'string', + 'enum' => ['true'], + ], + ]; + + $spec['components']['headers']['sunset'] = [ + 'description' => 'The date and time after which the endpoint will no longer be supported.', + 'schema' => [ + 'type' => 'string', + 'format' => 'date-time', + ], + ]; + + $spec['components']['headers']['link'] = [ + 'description' => 'Reference to the successor endpoint, when one is provided.', + 'schema' => [ + 'type' => 'string', + ], + ]; + + $spec['components']['headers']['xapiwarn'] = [ + 'description' => 'Human-readable deprecation warning for clients.', + 'schema' => [ + 'type' => 'string', + ], + ]; + + return $spec; + } + + /** + * Extend an individual operation. + */ + public function extendOperation(array $operation, Route $route, string $method, array $config): array + { + $sunset = $this->sunsetMiddlewareArguments($route); + + if ($sunset === null) { + return $operation; + } + + $operation['deprecated'] = true; + + foreach ($operation['responses'] as $status => &$response) { + if (! is_numeric($status) || (int) $status < 200 || (int) $status >= 300) { + continue; + } + + $response['headers'] = $response['headers'] ?? []; + + $response['headers']['Deprecation'] = [ + '$ref' => '#/components/headers/deprecation', + ]; + if ($sunset['sunsetDate'] !== null && $sunset['sunsetDate'] !== '') { + $response['headers']['Sunset'] = [ + '$ref' => '#/components/headers/sunset', + ]; + } + $response['headers']['X-API-Warn'] = [ + '$ref' => '#/components/headers/xapiwarn', + ]; + + if ( + $sunset['replacement'] !== null + && $sunset['replacement'] !== '' + && ! isset($response['headers']['Link']) + ) { + $response['headers']['Link'] = [ + '$ref' => '#/components/headers/link', + ]; + } + } + unset($response); + + return $operation; + } + + /** + * Extract the configured sunset middleware arguments from a route. + * + * Returns null when the route does not use the sunset middleware. + * + * @return array{sunsetDate:?string,replacement:?string}|null + */ + protected function sunsetMiddlewareArguments(Route $route): ?array + { + foreach ($route->middleware() as $middleware) { + if (! str_starts_with($middleware, 'api.sunset') && ! str_contains($middleware, 'ApiSunset')) { + continue; + } + + $arguments = null; + + if (str_contains($middleware, ':')) { + [, $arguments] = explode(':', $middleware, 2); + } + + if ($arguments === null || $arguments === '') { + return [ + 'sunsetDate' => null, + 'replacement' => null, + ]; + } + + $parts = explode(',', $arguments, 2); + $sunsetDate = trim($parts[0] ?? ''); + $replacement = isset($parts[1]) ? trim($parts[1]) : null; + if ($replacement === '') { + $replacement = null; + } + + return [ + 'sunsetDate' => $sunsetDate !== '' ? $sunsetDate : null, + 'replacement' => $replacement, + ]; + } + + return null; + } +} diff --git a/src/php/src/Api/Documentation/Extensions/VersionExtension.php b/src/php/src/Api/Documentation/Extensions/VersionExtension.php new file mode 100644 index 0000000..04cead7 --- /dev/null +++ b/src/php/src/Api/Documentation/Extensions/VersionExtension.php @@ -0,0 +1,126 @@ + 'API version used to process the request.', + 'schema' => [ + 'type' => 'string', + ], + ]; + + return $spec; + } + + /** + * Extend an individual operation. + */ + public function extendOperation(array $operation, Route $route, string $method, array $config): array + { + $version = $this->versionMiddlewareVersion($route); + if ($version === null) { + return $operation; + } + + $includeVersion = (bool) config('api.headers.include_version', true); + $includeDeprecation = (bool) config('api.headers.include_deprecation', true); + + $deprecatedVersions = array_map('intval', config('api.versioning.deprecated', [])); + $sunsetDates = config('api.versioning.sunset', []); + $isDeprecatedVersion = in_array($version, $deprecatedVersions, true); + $sunsetDate = $sunsetDates[$version] ?? null; + + if ($isDeprecatedVersion) { + $operation['deprecated'] = true; + } + + foreach ($operation['responses'] as $status => &$response) { + if (! is_numeric($status) || (int) $status < 200 || (int) $status >= 600) { + continue; + } + + $response['headers'] = $response['headers'] ?? []; + + if ($includeVersion && ! isset($response['headers']['X-API-Version'])) { + $response['headers']['X-API-Version'] = [ + '$ref' => '#/components/headers/xapiversion', + ]; + } + + if (! $includeDeprecation || ! $isDeprecatedVersion) { + continue; + } + + $response['headers']['Deprecation'] = [ + '$ref' => '#/components/headers/deprecation', + ]; + $response['headers']['X-API-Warn'] = [ + '$ref' => '#/components/headers/xapiwarn', + ]; + + if ($sunsetDate !== null && $sunsetDate !== '') { + $response['headers']['Sunset'] = [ + '$ref' => '#/components/headers/sunset', + ]; + } + } + unset($response); + + return $operation; + } + + /** + * Extract the version number from api.version middleware. + */ + protected function versionMiddlewareVersion(Route $route): ?int + { + foreach ($route->middleware() as $middleware) { + if (! str_starts_with($middleware, 'api.version') && ! str_contains($middleware, 'ApiVersion')) { + continue; + } + + if (! str_contains($middleware, ':')) { + return null; + } + + [, $arguments] = explode(':', $middleware, 2); + $arguments = trim($arguments); + if ($arguments === '') { + return null; + } + + $parts = explode(',', $arguments, 2); + $version = ltrim(trim($parts[0] ?? ''), 'vV'); + if ($version === '' || ! is_numeric($version)) { + return null; + } + + return (int) $version; + } + + return null; + } +} diff --git a/src/php/src/Api/Documentation/OpenApiBuilder.php b/src/php/src/Api/Documentation/OpenApiBuilder.php index 8a21b8e..4770720 100644 --- a/src/php/src/Api/Documentation/OpenApiBuilder.php +++ b/src/php/src/Api/Documentation/OpenApiBuilder.php @@ -11,6 +11,8 @@ use Core\Api\Documentation\Attributes\ApiTag; use Core\Api\Documentation\Extensions\ApiKeyAuthExtension; use Core\Api\Documentation\Extensions\RateLimitExtension; +use Core\Api\Documentation\Extensions\SunsetExtension; +use Core\Api\Documentation\Extensions\VersionExtension; use Core\Api\Documentation\Extensions\WorkspaceHeaderExtension; use Illuminate\Http\Resources\Json\JsonResource; use Illuminate\Routing\Route; @@ -57,7 +59,9 @@ protected function registerDefaultExtensions(): void { $this->extensions = [ new WorkspaceHeaderExtension, + new VersionExtension, new RateLimitExtension, + new SunsetExtension, new ApiKeyAuthExtension, ]; } @@ -229,6 +233,7 @@ protected function buildTags(array $config): array protected function buildPaths(array $config): array { $paths = []; + $operationIds = []; $includePatterns = $config['routes']['include'] ?? ['api/*']; $excludePatterns = $config['routes']['exclude'] ?? []; @@ -243,7 +248,7 @@ protected function buildPaths(array $config): array foreach ($methods as $method) { $method = strtolower($method); - $operation = $this->buildOperation($route, $method, $config); + $operation = $this->buildOperation($route, $method, $config, $operationIds); if ($operation !== null) { $paths[$path][$method] = $operation; @@ -297,7 +302,7 @@ protected function normalizePath(string $uri): string /** * Build operation for a specific route and method. */ - protected function buildOperation(Route $route, string $method, array $config): ?array + protected function buildOperation(Route $route, string $method, array $config, array &$operationIds): ?array { $controller = $route->getController(); $action = $route->getActionMethod(); @@ -309,7 +314,7 @@ protected function buildOperation(Route $route, string $method, array $config): $operation = [ 'summary' => $this->buildSummary($route, $method), - 'operationId' => $this->buildOperationId($route, $method), + 'operationId' => $this->buildOperationId($route, $method, $operationIds), 'tags' => $this->buildOperationTags($route, $controller, $action), 'responses' => $this->buildResponses($controller, $action), ]; @@ -328,7 +333,7 @@ protected function buildOperation(Route $route, string $method, array $config): // Add request body for POST/PUT/PATCH if (in_array($method, ['post', 'put', 'patch'])) { - $operation['requestBody'] = $this->buildRequestBody($controller, $action); + $operation['requestBody'] = $this->buildRequestBody($route, $controller, $action); } // Add security requirements @@ -398,15 +403,24 @@ protected function buildSummary(Route $route, string $method): string /** * Build operation ID from route name. */ - protected function buildOperationId(Route $route, string $method): string + protected function buildOperationId(Route $route, string $method, array &$operationIds): string { $name = $route->getName(); if ($name) { - return Str::camel(str_replace(['.', '-'], '_', $name)); + $base = Str::camel(str_replace(['.', '-'], '_', $name)); + } else { + $base = Str::camel($method.'_'.str_replace(['/', '-', '.'], '_', $route->uri())); } - return Str::camel($method.'_'.str_replace(['/', '-', '.'], '_', $route->uri())); + $count = $operationIds[$base] ?? 0; + $operationIds[$base] = $count + 1; + + if ($count === 0) { + return $base; + } + + return $base.'_'.($count + 1); } /** @@ -511,16 +525,36 @@ protected function extractDescription(?object $controller, string $action): ?str protected function buildParameters(Route $route, ?object $controller, string $action, array $config): array { $parameters = []; + $parameterIndex = []; + + $addParameter = function (array $parameter) use (&$parameters, &$parameterIndex): void { + $name = $parameter['name'] ?? null; + $in = $parameter['in'] ?? null; + + if (! is_string($name) || $name === '' || ! is_string($in) || $in === '') { + return; + } + + $key = $in.':'.$name; + if (isset($parameterIndex[$key])) { + $parameters[$parameterIndex[$key]] = $parameter; + + return; + } + + $parameterIndex[$key] = count($parameters); + $parameters[] = $parameter; + }; // Add path parameters preg_match_all('/\{([^}?]+)\??}/', $route->uri(), $matches); foreach ($matches[1] as $param) { - $parameters[] = [ + $addParameter([ 'name' => $param, 'in' => 'path', 'required' => true, 'schema' => ['type' => 'string'], - ]; + ]); } // Add parameters from ApiParameter attributes @@ -532,12 +566,12 @@ protected function buildParameters(Route $route, ?object $controller, string $ac foreach ($paramAttrs as $attr) { $param = $attr->newInstance(); - $parameters[] = $param->toOpenApi(); + $addParameter($param->toOpenApi()); } } } - return $parameters; + return array_values($parameters); } /** @@ -578,15 +612,23 @@ protected function buildResponseSchema(ApiResponse $response): array 'description' => $response->getDescription(), ]; - if ($response->resource !== null && class_exists($response->resource)) { + $schema = null; + + if (is_array($response->schema) && ! empty($response->schema)) { + $schema = $response->schema; + } elseif ($response->resource !== null && class_exists($response->resource)) { $schema = $this->extractResourceSchema($response->resource); if ($response->paginated) { $schema = $this->wrapPaginatedSchema($schema); } + } + + if ($schema !== null) { + $contentType = $response->contentType ?: 'application/json'; $result['content'] = [ - 'application/json' => [ + $contentType => [ 'schema' => $schema, ], ]; @@ -614,10 +656,177 @@ protected function extractResourceSchema(string $resourceClass): array return ['type' => 'object']; } - // For now, return a generic object schema - // A more sophisticated implementation would analyze the resource's toArray method + try { + $resource = new $resourceClass(new \stdClass); + $data = $resource->toArray(request()); + + if (is_array($data)) { + return $this->inferArraySchema($data); + } + } catch (\Throwable) { + // Fall back to a generic object schema when the resource cannot + // be instantiated safely in the current context. + } + + return [ + 'type' => 'object', + 'additionalProperties' => true, + ]; + } + + /** + * Infer an OpenAPI schema from a PHP array. + */ + protected function inferArraySchema(array $value): array + { + if (array_is_list($value)) { + $itemSchema = ['type' => 'object']; + + foreach ($value as $item) { + if ($item === null) { + continue; + } + + $itemSchema = $this->inferValueSchema($item); + break; + } + + return [ + 'type' => 'array', + 'items' => $itemSchema, + ]; + } + + $properties = []; + foreach ($value as $key => $item) { + $properties[(string) $key] = $this->inferValueSchema($item, (string) $key); + } + + return [ + 'type' => 'object', + 'properties' => $properties, + 'additionalProperties' => true, + ]; + } + + /** + * Infer an OpenAPI schema node from a PHP value. + */ + protected function inferValueSchema(mixed $value, ?string $key = null): array + { + if ($value === null) { + return $this->inferNullableSchema($key); + } + + if (is_bool($value)) { + return ['type' => 'boolean']; + } + + if (is_int($value)) { + return ['type' => 'integer']; + } + + if (is_float($value)) { + return ['type' => 'number']; + } + + if (is_string($value)) { + return $this->inferStringSchema($value, $key); + } + + if (is_array($value)) { + return $this->inferArraySchema($value); + } + + if (is_object($value)) { + return $this->inferObjectSchema($value); + } + + return []; + } + + /** + * Infer a schema for a null value using the field name as a hint. + */ + protected function inferNullableSchema(?string $key): array + { + if ($key === null) { + return ['nullable' => true]; + } + + $normalized = strtolower($key); + + return match (true) { + $normalized === 'id', + str_ends_with($normalized, '_id'), + str_ends_with($normalized, 'count'), + str_ends_with($normalized, 'total'), + str_ends_with($normalized, 'page'), + str_ends_with($normalized, 'limit'), + str_ends_with($normalized, 'offset'), + str_ends_with($normalized, 'size'), + str_ends_with($normalized, 'quantity'), + str_ends_with($normalized, 'rank'), + str_ends_with($normalized, 'score') => ['type' => 'integer', 'nullable' => true], + str_starts_with($normalized, 'is_'), + str_starts_with($normalized, 'has_'), + str_starts_with($normalized, 'can_'), + str_starts_with($normalized, 'should_'), + str_starts_with($normalized, 'enabled'), + str_starts_with($normalized, 'active') => ['type' => 'boolean', 'nullable' => true], + str_ends_with($normalized, '_at'), + str_ends_with($normalized, '_on'), + str_contains($normalized, 'date'), + str_contains($normalized, 'time'), + str_contains($normalized, 'timestamp') => ['type' => 'string', 'format' => 'date-time', 'nullable' => true], + str_contains($normalized, 'email') => ['type' => 'string', 'format' => 'email', 'nullable' => true], + str_contains($normalized, 'url'), + str_contains($normalized, 'uri') => ['type' => 'string', 'format' => 'uri', 'nullable' => true], + str_contains($normalized, 'uuid') => ['type' => 'string', 'format' => 'uuid', 'nullable' => true], + str_contains($normalized, 'name'), + str_contains($normalized, 'title'), + str_contains($normalized, 'description'), + str_contains($normalized, 'status'), + str_contains($normalized, 'type'), + str_contains($normalized, 'code'), + str_contains($normalized, 'token'), + str_contains($normalized, 'slug'), + str_contains($normalized, 'key') => ['type' => 'string', 'nullable' => true], + default => ['nullable' => true], + }; + } + + /** + * Infer a schema for a string value using the field name as a hint. + */ + protected function inferStringSchema(string $value, ?string $key): array + { + if ($key !== null) { + $nullable = $this->inferNullableSchema($key); + + if (($nullable['type'] ?? null) === 'string') { + $nullable['nullable'] = false; + return $nullable; + } + } + + return ['type' => 'string']; + } + + /** + * Infer a schema for an object value. + */ + protected function inferObjectSchema(object $value): array + { + $properties = []; + + foreach (get_object_vars($value) as $key => $item) { + $properties[$key] = $this->inferValueSchema($item, (string) $key); + } + return [ 'type' => 'object', + 'properties' => $properties, 'additionalProperties' => true, ]; } @@ -661,8 +870,45 @@ protected function wrapPaginatedSchema(array $itemSchema): array /** * Build request body schema. */ - protected function buildRequestBody(?object $controller, string $action): array + protected function buildRequestBody(Route $route, ?object $controller, string $action): array { + if ($controller instanceof \Core\Api\Controllers\McpApiController && $action === 'callTool') { + return [ + 'required' => true, + 'content' => [ + 'application/json' => [ + 'schema' => [ + 'type' => 'object', + 'properties' => [ + 'server' => [ + 'type' => 'string', + 'maxLength' => 64, + 'description' => 'MCP server identifier.', + ], + 'tool' => [ + 'type' => 'string', + 'maxLength' => 128, + 'description' => 'Tool name to invoke on the selected server.', + ], + 'arguments' => [ + 'type' => 'object', + 'description' => 'Tool arguments passed through to MCP.', + 'additionalProperties' => true, + ], + 'version' => [ + 'type' => 'string', + 'maxLength' => 32, + 'description' => 'Optional tool version to execute. Defaults to the latest supported version.', + ], + ], + 'required' => ['server', 'tool'], + 'additionalProperties' => true, + ], + ], + ], + ]; + } + return [ 'required' => true, 'content' => [ diff --git a/src/php/src/Api/Documentation/Routes/docs.php b/src/php/src/Api/Documentation/Routes/docs.php index 5ff04e2..5531a77 100644 --- a/src/php/src/Api/Documentation/Routes/docs.php +++ b/src/php/src/Api/Documentation/Routes/docs.php @@ -20,6 +20,7 @@ Route::get('/swagger', [DocumentationController::class, 'swagger'])->name('api.docs.swagger'); Route::get('/scalar', [DocumentationController::class, 'scalar'])->name('api.docs.scalar'); Route::get('/redoc', [DocumentationController::class, 'redoc'])->name('api.docs.redoc'); +Route::get('/stoplight', [DocumentationController::class, 'stoplight'])->name('api.docs.stoplight'); // OpenAPI specification routes Route::get('/openapi.json', [DocumentationController::class, 'openApiJson']) diff --git a/src/php/src/Api/Documentation/Views/stoplight.blade.php b/src/php/src/Api/Documentation/Views/stoplight.blade.php new file mode 100644 index 0000000..803eed3 --- /dev/null +++ b/src/php/src/Api/Documentation/Views/stoplight.blade.php @@ -0,0 +1,34 @@ + + + + + + + {{ config('api-docs.info.title', 'API Documentation') }} - Stoplight + + + + + + + + + diff --git a/src/php/src/Api/Documentation/config.php b/src/php/src/Api/Documentation/config.php index 0c43186..378898a 100644 --- a/src/php/src/Api/Documentation/config.php +++ b/src/php/src/Api/Documentation/config.php @@ -268,6 +268,13 @@ 'hide_download_button' => false, 'hide_models' => false, ], + + // Stoplight Elements specific options + 'stoplight' => [ + 'theme' => 'dark', // 'dark' or 'light' + 'layout' => 'sidebar', // 'sidebar' or 'stacked' + 'hide_try_it' => false, + ], ], /* diff --git a/src/php/src/Api/Exceptions/RateLimitExceededException.php b/src/php/src/Api/Exceptions/RateLimitExceededException.php index cad4e41..8eb7401 100644 --- a/src/php/src/Api/Exceptions/RateLimitExceededException.php +++ b/src/php/src/Api/Exceptions/RateLimitExceededException.php @@ -5,7 +5,9 @@ namespace Core\Api\Exceptions; use Core\Api\RateLimit\RateLimitResult; +use Core\Api\Concerns\HasApiResponses; use Illuminate\Http\JsonResponse; +use Illuminate\Http\Request; use Symfony\Component\HttpKernel\Exception\HttpException; /** @@ -15,6 +17,8 @@ */ class RateLimitExceededException extends HttpException { + use HasApiResponses; + public function __construct( protected RateLimitResult $rateLimitResult, string $message = 'Too many requests. Please slow down.', @@ -33,15 +37,22 @@ public function getRateLimitResult(): RateLimitResult /** * Render the exception as a JSON response. */ - public function render(): JsonResponse + public function render(?Request $request = null): JsonResponse { - return response()->json([ - 'error' => 'rate_limit_exceeded', - 'message' => $this->getMessage(), - 'retry_after' => $this->rateLimitResult->retryAfter, - 'limit' => $this->rateLimitResult->limit, - 'resets_at' => $this->rateLimitResult->resetsAt->toIso8601String(), - ], 429, $this->rateLimitResult->headers()); + // Return the rate-limit error response with rate-limit headers attached. + // CORS headers are intentionally omitted here; they are applied by the + // framework's CORS middleware (or PublicApiCors) which handles patterns, + // credentials, and Vary correctly for all responses — including errors. + return $this->errorResponse( + errorCode: 'rate_limit_exceeded', + message: $this->getMessage(), + meta: [ + 'retry_after' => $this->rateLimitResult->retryAfter, + 'limit' => $this->rateLimitResult->limit, + 'resets_at' => $this->rateLimitResult->resetsAt->toIso8601String(), + ], + status: 429, + )->withHeaders($this->rateLimitResult->headers()); } /** diff --git a/src/php/src/Api/Middleware/AuthenticateApiKey.php b/src/php/src/Api/Middleware/AuthenticateApiKey.php index 40b6fe9..ecddfc5 100644 --- a/src/php/src/Api/Middleware/AuthenticateApiKey.php +++ b/src/php/src/Api/Middleware/AuthenticateApiKey.php @@ -6,6 +6,7 @@ use Core\Api\Models\ApiKey; use Core\Api\Services\IpRestrictionService; +use Core\Api\Concerns\HasApiResponses; use Closure; use Illuminate\Http\Request; use Symfony\Component\HttpFoundation\Response; @@ -24,6 +25,8 @@ */ class AuthenticateApiKey { + use HasApiResponses; + public function handle(Request $request, Closure $next, ?string $scope = null): Response { $token = $request->bearerToken(); @@ -113,14 +116,15 @@ protected function authenticateSanctum( } /** - * Return 401 Unauthorized response. + * Return 401 Unauthorised response. */ protected function unauthorized(string $message): Response { - return response()->json([ - 'error' => 'unauthorized', - 'message' => $message, - ], 401); + return $this->errorResponse( + errorCode: 'unauthorized', + message: $message, + status: 401, + ); } /** @@ -128,9 +132,6 @@ protected function unauthorized(string $message): Response */ protected function forbidden(string $message): Response { - return response()->json([ - 'error' => 'forbidden', - 'message' => $message, - ], 403); + return $this->forbiddenResponse($message, status: 403); } } diff --git a/src/php/src/Api/Middleware/CheckApiScope.php b/src/php/src/Api/Middleware/CheckApiScope.php index 32aeec0..614bb9c 100644 --- a/src/php/src/Api/Middleware/CheckApiScope.php +++ b/src/php/src/Api/Middleware/CheckApiScope.php @@ -5,6 +5,7 @@ namespace Core\Api\Middleware; use Core\Api\Models\ApiKey; +use Core\Api\Concerns\HasApiResponses; use Closure; use Illuminate\Http\Request; use Symfony\Component\HttpFoundation\Response; @@ -25,6 +26,8 @@ */ class CheckApiScope { + use HasApiResponses; + public function handle(Request $request, Closure $next, string ...$scopes): Response { $apiKey = $request->attributes->get('api_key'); @@ -38,12 +41,13 @@ public function handle(Request $request, Closure $next, string ...$scopes): Resp // Check all required scopes foreach ($scopes as $scope) { if (! $apiKey->hasScope($scope)) { - return response()->json([ - 'error' => 'forbidden', - 'message' => "API key missing required scope: {$scope}", - 'required_scopes' => $scopes, - 'key_scopes' => $apiKey->scopes, - ], 403); + return $this->forbiddenResponse( + message: "API key missing required scope: {$scope}", + meta: [ + 'required_scopes' => $scopes, + 'key_scopes' => $apiKey->scopes, + ], + ); } } diff --git a/src/php/src/Api/Middleware/EnforceApiScope.php b/src/php/src/Api/Middleware/EnforceApiScope.php index 958e3fa..9b83495 100644 --- a/src/php/src/Api/Middleware/EnforceApiScope.php +++ b/src/php/src/Api/Middleware/EnforceApiScope.php @@ -6,6 +6,7 @@ use Closure; use Core\Api\Models\ApiKey; +use Core\Api\Concerns\HasApiResponses; use Illuminate\Http\Request; use Symfony\Component\HttpFoundation\Response; @@ -25,6 +26,8 @@ */ class EnforceApiScope { + use HasApiResponses; + /** * HTTP method to required scope mapping. */ @@ -52,12 +55,13 @@ public function handle(Request $request, Closure $next): Response $requiredScope = self::METHOD_SCOPES[$method] ?? ApiKey::SCOPE_READ; if (! $apiKey->hasScope($requiredScope)) { - return response()->json([ - 'error' => 'forbidden', - 'message' => "API key missing required scope: {$requiredScope}", - 'detail' => "{$method} requests require '{$requiredScope}' scope", - 'key_scopes' => $apiKey->scopes, - ], 403); + return $this->forbiddenResponse( + message: "API key missing required scope: {$requiredScope}", + meta: [ + 'detail' => "{$method} requests require '{$requiredScope}' scope", + 'key_scopes' => $apiKey->scopes, + ], + ); } return $next($request); diff --git a/src/php/src/Api/Resources/ErrorResource.php b/src/php/src/Api/Resources/ErrorResource.php index ca62eca..a39ba38 100644 --- a/src/php/src/Api/Resources/ErrorResource.php +++ b/src/php/src/Api/Resources/ErrorResource.php @@ -40,7 +40,7 @@ public static function make(...$args): static /** * Common error factory methods. */ - public static function unauthorized(string $message = 'Unauthorized'): static + public static function unauthorized(string $message = 'Unauthorised'): static { return new static('unauthorized', $message); } diff --git a/src/php/src/Api/Routes/api.php b/src/php/src/Api/Routes/api.php index cec4478..26f6a69 100644 --- a/src/php/src/Api/Routes/api.php +++ b/src/php/src/Api/Routes/api.php @@ -2,7 +2,12 @@ declare(strict_types=1); +use Core\Api\Controllers\Api\UnifiedPixelController; +use Core\Api\Controllers\Api\EntitlementApiController; +use Core\Api\Controllers\Api\SeoReportController; +use Core\Api\Controllers\Api\WebhookSecretController; use Core\Api\Controllers\McpApiController; +use Core\Api\Middleware\PublicApiCors; use Core\Mcp\Middleware\McpApiKeyAuth; use Illuminate\Support\Facades\Route; @@ -13,11 +18,81 @@ | | Core API routes for cross-cutting concerns. | -| TODO: SeoReportController, UnifiedPixelController, EntitlementApiController -| are planned but not yet implemented. Re-add routes when controllers exist. +| SEO, pixel tracking, entitlements, and MCP bridge endpoints. | */ +// ───────────────────────────────────────────────────────────────────────────── +// Unified Pixel (public tracking) +// ───────────────────────────────────────────────────────────────────────────── + +Route::middleware([PublicApiCors::class, 'api.rate']) + ->prefix('pixel') + ->name('api.pixel.') + ->group(function () { + Route::match(['GET', 'POST', 'OPTIONS'], '/{pixelKey}', [UnifiedPixelController::class, 'track']) + ->name('track'); + }); + +// ───────────────────────────────────────────────────────────────────────────── +// SEO analysis (authenticated) +// ───────────────────────────────────────────────────────────────────────────── + +Route::middleware(['auth.api', 'api.scope.enforce']) + ->prefix('seo') + ->name('api.seo.') + ->group(function () { + Route::get('/report', [SeoReportController::class, 'show']) + ->name('report'); + }); + +// ───────────────────────────────────────────────────────────────────────────── +// Entitlements (authenticated) +// ───────────────────────────────────────────────────────────────────────────── + +Route::middleware(['auth.api', 'api.scope.enforce']) + ->prefix('entitlements') + ->name('api.entitlements.') + ->group(function () { + Route::get('/', [EntitlementApiController::class, 'show']) + ->name('show'); + }); + +// ───────────────────────────────────────────────────────────────────────────── +// Webhook secret rotation (authenticated) +// ───────────────────────────────────────────────────────────────────────────── + +Route::middleware(['auth.api', 'api.scope.enforce']) + ->prefix('webhooks') + ->name('api.webhooks.') + ->group(function () { + Route::prefix('social/{uuid}/secret') + ->name('social.') + ->group(function () { + Route::post('/rotate', [WebhookSecretController::class, 'rotateSocialSecret']) + ->name('rotate-secret'); + Route::get('/', [WebhookSecretController::class, 'socialSecretStatus']) + ->name('status'); + Route::delete('/previous', [WebhookSecretController::class, 'invalidateSocialPreviousSecret']) + ->name('invalidate-previous'); + Route::patch('/grace-period', [WebhookSecretController::class, 'updateSocialGracePeriod']) + ->name('grace-period'); + }); + + Route::prefix('content/{uuid}/secret') + ->name('content.') + ->group(function () { + Route::post('/rotate', [WebhookSecretController::class, 'rotateContentSecret']) + ->name('rotate-secret'); + Route::get('/', [WebhookSecretController::class, 'contentSecretStatus']) + ->name('status'); + Route::delete('/previous', [WebhookSecretController::class, 'invalidateContentPreviousSecret']) + ->name('invalidate-previous'); + Route::patch('/grace-period', [WebhookSecretController::class, 'updateContentGracePeriod']) + ->name('grace-period'); + }); + }); + // ───────────────────────────────────────────────────────────────────────────── // MCP HTTP Bridge (API key auth) // ───────────────────────────────────────────────────────────────────────────── @@ -34,6 +109,8 @@ ->name('servers.show'); Route::get('/servers/{id}/tools', [McpApiController::class, 'tools']) ->name('servers.tools'); + Route::get('/servers/{id}/resources', [McpApiController::class, 'resources']) + ->name('servers.resources'); // Tool version history (read) Route::get('/servers/{server}/tools/{tool}/versions', [McpApiController::class, 'toolVersions']) diff --git a/src/php/src/Api/Services/ApiUsageService.php b/src/php/src/Api/Services/ApiUsageService.php index 204f444..f5d7445 100644 --- a/src/php/src/Api/Services/ApiUsageService.php +++ b/src/php/src/Api/Services/ApiUsageService.php @@ -2,11 +2,12 @@ declare(strict_types=1); -namespace Mod\Api\Services; +namespace Core\Api\Services; use Carbon\Carbon; -use Mod\Api\Models\ApiUsage; -use Mod\Api\Models\ApiUsageDaily; +use Core\Api\Models\ApiKey; +use Core\Api\Models\ApiUsage; +use Core\Api\Models\ApiUsageDaily; /** * API Usage Service - tracks and reports API usage metrics. @@ -282,7 +283,7 @@ public function getKeyComparison( // Fetch API keys separately to avoid broken eager loading with aggregation $apiKeyIds = $aggregated->pluck('api_key_id')->filter()->unique()->all(); - $apiKeys = \Mod\Api\Models\ApiKey::whereIn('id', $apiKeyIds) + $apiKeys = ApiKey::whereIn('id', $apiKeyIds) ->select('id', 'name', 'prefix') ->get() ->keyBy('id'); diff --git a/src/php/src/Api/Services/SeoReportService.php b/src/php/src/Api/Services/SeoReportService.php new file mode 100644 index 0000000..70cd0e2 --- /dev/null +++ b/src/php/src/Api/Services/SeoReportService.php @@ -0,0 +1,534 @@ +validateUrlForSsrf($url); + + try { + $response = Http::withHeaders([ + 'User-Agent' => config('app.name', 'Core API').' SEO Reporter/1.0', + 'Accept' => 'text/html,application/xhtml+xml', + ]) + ->timeout((int) config('api.seo.timeout', 10)) + ->withoutRedirecting() + ->get($url) + ->throw(); + } catch (RuntimeException $exception) { + throw $exception; + } catch (Throwable $exception) { + throw new RuntimeException('Unable to fetch the requested URL.', 0, $exception); + } + + $html = (string) $response->body(); + $xpath = $this->loadXPath($html); + + $title = $this->extractSingleText($xpath, '//title'); + $description = $this->extractMetaContent($xpath, 'description'); + $canonical = $this->extractLinkHref($xpath, 'canonical'); + $robots = $this->extractMetaContent($xpath, 'robots'); + $language = $this->extractHtmlAttribute($xpath, 'lang'); + $charset = $this->extractCharset($xpath); + + $openGraph = [ + 'title' => $this->extractMetaContent($xpath, 'og:title', 'property'), + 'description' => $this->extractMetaContent($xpath, 'og:description', 'property'), + 'image' => $this->extractMetaContent($xpath, 'og:image', 'property'), + 'type' => $this->extractMetaContent($xpath, 'og:type', 'property'), + 'site_name' => $this->extractMetaContent($xpath, 'og:site_name', 'property'), + ]; + + $twitterCard = [ + 'card' => $this->extractMetaContent($xpath, 'twitter:card', 'name'), + 'title' => $this->extractMetaContent($xpath, 'twitter:title', 'name'), + 'description' => $this->extractMetaContent($xpath, 'twitter:description', 'name'), + 'image' => $this->extractMetaContent($xpath, 'twitter:image', 'name'), + ]; + + $headings = $this->countHeadings($xpath); + $issues = $this->buildIssues($title, $description, $canonical, $robots, $openGraph, $headings); + + return [ + 'url' => $url, + 'status_code' => $response->status(), + 'content_type' => $response->header('Content-Type'), + 'score' => $this->calculateScore($issues), + 'summary' => [ + 'title' => $title, + 'description' => $description, + 'canonical' => $canonical, + 'robots' => $robots, + 'language' => $language, + 'charset' => $charset, + ], + 'open_graph' => $openGraph, + 'twitter' => $twitterCard, + 'headings' => $headings, + 'issues' => $issues, + 'recommendations' => $this->buildRecommendations($issues), + ]; + } + + /** + * Load an HTML document into an XPath query object. + */ + protected function loadXPath(string $html): DOMXPath + { + $previous = libxml_use_internal_errors(true); + + $document = new DOMDocument(); + $document->loadHTML($html, LIBXML_NOERROR | LIBXML_NOWARNING); + + libxml_clear_errors(); + libxml_use_internal_errors($previous); + + return new DOMXPath($document); + } + + /** + * Extract the first text node matched by an XPath query. + */ + protected function extractSingleText(DOMXPath $xpath, string $query): ?string + { + $nodes = $xpath->query($query); + + if (! $nodes || $nodes->length === 0) { + return null; + } + + $node = $nodes->item(0); + + if (! $node) { + return null; + } + + $value = trim($node->textContent ?? ''); + + return $value !== '' ? $value : null; + } + + /** + * Extract a meta tag content value. + */ + protected function extractMetaContent(DOMXPath $xpath, string $name, string $attribute = 'name'): ?string + { + $query = sprintf('//meta[@%s=%s]/@content', $attribute, $this->quoteForXPath($name)); + $nodes = $xpath->query($query); + + if (! $nodes || $nodes->length === 0) { + return null; + } + + $node = $nodes->item(0); + + if (! $node) { + return null; + } + + $value = trim($node->textContent ?? ''); + + return $value !== '' ? $value : null; + } + + /** + * Extract a link href value. + */ + protected function extractLinkHref(DOMXPath $xpath, string $rel): ?string + { + $query = sprintf('//link[@rel=%s]/@href', $this->quoteForXPath($rel)); + $nodes = $xpath->query($query); + + if (! $nodes || $nodes->length === 0) { + return null; + } + + $node = $nodes->item(0); + + if (! $node) { + return null; + } + + $value = trim($node->textContent ?? ''); + + return $value !== '' ? $value : null; + } + + /** + * Extract the HTML lang attribute. + */ + protected function extractHtmlAttribute(DOMXPath $xpath, string $attribute): ?string + { + $nodes = $xpath->query(sprintf('//html/@%s', $attribute)); + + if (! $nodes || $nodes->length === 0) { + return null; + } + + $node = $nodes->item(0); + + if (! $node) { + return null; + } + + $value = trim($node->textContent ?? ''); + + return $value !== '' ? $value : null; + } + + /** + * Extract a charset declaration. + */ + protected function extractCharset(DOMXPath $xpath): ?string + { + $nodes = $xpath->query('//meta[@charset]/@charset'); + + if ($nodes && $nodes->length > 0) { + $node = $nodes->item(0); + + if ($node) { + $value = trim($node->textContent ?? ''); + + if ($value !== '') { + return $value; + } + } + } + + // The http-equiv Content-Type meta returns a full value such as + // "text/html; charset=utf-8". Extract only the charset token so that + // callers receive a bare encoding label (e.g. "utf-8"), not the whole + // content-type string. + $contentType = $this->extractMetaContent($xpath, 'content-type', 'http-equiv'); + if ($contentType !== null) { + if (preg_match('/charset\s*=\s*["\']?([^\s;"\']+)/i', $contentType, $matches)) { + return $matches[1]; + } + } + + return null; + } + + /** + * Count headings by level. + * + * @return array + */ + protected function countHeadings(DOMXPath $xpath): array + { + $counts = []; + + for ($level = 1; $level <= 6; $level++) { + $nodes = $xpath->query('//h'.$level); + $counts['h'.$level] = $nodes ? $nodes->length : 0; + } + + return $counts; + } + + /** + * Build issue list from the extracted SEO data. + * + * @return array> + */ + protected function buildIssues( + ?string $title, + ?string $description, + ?string $canonical, + ?string $robots, + array $openGraph, + array $headings + ): array { + $issues = []; + + if ($title === null) { + $issues[] = $this->issue('missing_title', 'No tag was found.', 'high'); + } elseif (Str::length($title) < 10) { + $issues[] = $this->issue('title_too_short', 'The page title is shorter than 10 characters.', 'medium'); + } elseif (Str::length($title) > 60) { + $issues[] = $this->issue('title_too_long', 'The page title is longer than 60 characters.', 'medium'); + } + + if ($description === null) { + $issues[] = $this->issue('missing_description', 'No meta description was found.', 'high'); + } + + if ($canonical === null) { + $issues[] = $this->issue('missing_canonical', 'No canonical URL was found.', 'medium'); + } + + if (($headings['h1'] ?? 0) === 0) { + $issues[] = $this->issue('missing_h1', 'The page does not contain an H1 heading.', 'high'); + } elseif (($headings['h1'] ?? 0) > 1) { + $issues[] = $this->issue('multiple_h1', 'The page contains multiple H1 headings.', 'medium'); + } + + if (($openGraph['title'] ?? null) === null) { + $issues[] = $this->issue('missing_og_title', 'No Open Graph title was found.', 'low'); + } + + if (($openGraph['description'] ?? null) === null) { + $issues[] = $this->issue('missing_og_description', 'No Open Graph description was found.', 'low'); + } + + if ($robots !== null && Str::contains(Str::lower($robots), ['noindex', 'nofollow'])) { + $issues[] = $this->issue('robots_restricted', 'Robots directives block indexing or following links.', 'high'); + } + + return $issues; + } + + /** + * Convert a list of issues to a report score. + */ + protected function calculateScore(array $issues): int + { + $penalties = [ + 'missing_title' => 20, + 'title_too_short' => 5, + 'title_too_long' => 5, + 'missing_description' => 15, + 'missing_canonical' => 10, + 'missing_h1' => 15, + 'multiple_h1' => 5, + 'missing_og_title' => 5, + 'missing_og_description' => 5, + 'robots_restricted' => 20, + ]; + + $score = 100; + + foreach ($issues as $issue) { + $score -= $penalties[$issue['code']] ?? 0; + } + + return max(0, $score); + } + + /** + * Build recommendations from issues. + * + * @return array<int, string> + */ + protected function buildRecommendations(array $issues): array + { + $recommendations = []; + + foreach ($issues as $issue) { + $recommendations[] = match ($issue['code']) { + 'missing_title' => 'Add a concise page title that describes the page content.', + 'title_too_short' => 'Expand the page title so it is more descriptive.', + 'title_too_long' => 'Shorten the page title to keep it under 60 characters.', + 'missing_description' => 'Add a meta description to improve search snippets.', + 'missing_canonical' => 'Add a canonical URL to prevent duplicate content issues.', + 'missing_h1' => 'Add a single, descriptive H1 heading.', + 'multiple_h1' => 'Reduce the page to a single primary H1 heading.', + 'missing_og_title' => 'Add an Open Graph title for better social sharing.', + 'missing_og_description' => 'Add an Open Graph description for better social sharing.', + 'robots_restricted' => 'Remove noindex or nofollow directives if the page should be indexed.', + default => $issue['message'], + }; + } + + return array_values(array_unique($recommendations)); + } + + /** + * Build an issue record. + * + * @return array{code: string, message: string, severity: string} + */ + protected function issue(string $code, string $message, string $severity): array + { + return [ + 'code' => $code, + 'message' => $message, + 'severity' => $severity, + ]; + } + + /** + * Validate that a URL is safe to fetch and does not target internal/private + * network resources (SSRF protection). + * + * Blocks: + * - Non-HTTP/HTTPS schemes + * - Loopback addresses (127.0.0.0/8, ::1) + * - RFC-1918 private ranges (10/8, 172.16/12, 192.168/16) + * - Link-local ranges (169.254.0.0/16, fe80::/10) + * - IPv6 ULA (fc00::/7) + * + * @throws RuntimeException when the URL fails SSRF validation. + */ + protected function validateUrlForSsrf(string $url): void + { + $parsed = parse_url($url); + + if ($parsed === false || empty($parsed['scheme']) || empty($parsed['host'])) { + throw new RuntimeException('The supplied URL is not valid.'); + } + + if (! in_array(strtolower($parsed['scheme']), ['http', 'https'], true)) { + throw new RuntimeException('Only HTTP and HTTPS URLs are permitted.'); + } + + $host = $parsed['host']; + + // If the host is an IP literal (IPv4 or bracketed IPv6), validate it + // directly. dns_get_record returns nothing for IP literals and + // gethostbyname returns the same value, so both would silently skip + // the private-range check without this explicit guard. + $normalised = ltrim(rtrim($host, ']'), '['); // strip IPv6 brackets + if (filter_var($normalised, FILTER_VALIDATE_IP) !== false) { + if ($this->isPrivateIp($normalised)) { + throw new RuntimeException('The supplied URL resolves to a private or reserved address.'); + } + + // Valid public IP literal — no DNS lookup required. + return; + } + + $records = dns_get_record($host, DNS_A | DNS_AAAA) ?: []; + + // Fall back to gethostbyname for hosts not returned by dns_get_record. + if (empty($records)) { + $resolved = gethostbyname($host); + if ($resolved !== $host) { + $records[] = ['ip' => $resolved]; + } + } + + foreach ($records as $record) { + $ip = $record['ip'] ?? $record['ipv6'] ?? null; + if ($ip === null) { + continue; + } + if ($this->isPrivateIp($ip)) { + throw new RuntimeException('The supplied URL resolves to a private or reserved address.'); + } + } + } + + /** + * Return true when an IP address falls within a private, loopback, or + * link-local range. + */ + protected function isPrivateIp(string $ip): bool + { + // inet_pton returns false for invalid addresses. + $packed = inet_pton($ip); + if ($packed === false) { + return true; // Treat unresolvable as unsafe. + } + + if (strlen($packed) === 4) { + return $this->isPrivateIpv4($ip); + } + + // IPv6 checks. + + // ::ffff:0:0/96 — IPv4-mapped addresses (e.g. ::ffff:127.0.0.1). + // The first 10 bytes are 0x00, bytes 10-11 are 0xff 0xff, then 4 + // bytes of IPv4. Evaluate the embedded IPv4 address against the + // standard private ranges. + if (str_repeat("\x00", 10) . "\xff\xff" === substr($packed, 0, 12)) { + $ipv4 = inet_ntop(substr($packed, 12, 4)); + if ($ipv4 !== false && $this->isPrivateIpv4($ipv4)) { + return true; + } + } + + // Loopback (::1). + if ($ip === '::1') { + return true; + } + $prefix2 = strtolower(substr(bin2hex($packed), 0, 2)); + // fe80::/10 — first byte 0xfe, second byte 0x80–0xbf + if ($prefix2 === 'fe') { + $secondNibble = hexdec(substr(bin2hex($packed), 2, 1)); + if ($secondNibble >= 8 && $secondNibble <= 11) { + return true; + } + } + // fc00::/7 — first byte 0xfc or 0xfd + if (in_array($prefix2, ['fc', 'fd'], true)) { + return true; + } + + return false; + } + + /** + * Return true when an IPv4 address string falls within a private, + * loopback, link-local, or reserved range. + * + * Handles 0.0.0.0/8 (RFC 1122 "this network"), 127/8 (loopback), + * 10/8, 172.16/12, 192.168/16 (RFC 1918), and 169.254/16 (link-local). + */ + protected function isPrivateIpv4(string $ip): bool + { + $long = ip2long($ip); + if ($long === false) { + return true; // Treat unparsable as unsafe. + } + + $privateRanges = [ + ['start' => ip2long('0.0.0.0'), 'end' => ip2long('0.255.255.255')], // 0.0.0.0/8 (RFC 1122) + ['start' => ip2long('127.0.0.0'), 'end' => ip2long('127.255.255.255')], // loopback + ['start' => ip2long('10.0.0.0'), 'end' => ip2long('10.255.255.255')], // RFC-1918 + ['start' => ip2long('172.16.0.0'), 'end' => ip2long('172.31.255.255')], // RFC-1918 + ['start' => ip2long('192.168.0.0'), 'end' => ip2long('192.168.255.255')], // RFC-1918 + ['start' => ip2long('169.254.0.0'), 'end' => ip2long('169.254.255.255')], // link-local + ]; + + foreach ($privateRanges as $range) { + if ($long >= $range['start'] && $long <= $range['end']) { + return true; + } + } + + return false; + } + + /** + * Quote a literal for XPath queries. + */ + protected function quoteForXPath(string $value): string + { + if (! str_contains($value, "'")) { + return "'{$value}'"; + } + + if (! str_contains($value, '"')) { + return '"'.$value.'"'; + } + + $parts = array_map( + fn (string $part) => "'{$part}'", + explode("'", $value) + ); + + return 'concat('.implode(", \"'\", ", $parts).')'; + } +} diff --git a/src/php/src/Api/Tests/Feature/ApiUsageTest.php b/src/php/src/Api/Tests/Feature/ApiUsageTest.php index 20c3f0d..74ae92b 100644 --- a/src/php/src/Api/Tests/Feature/ApiUsageTest.php +++ b/src/php/src/Api/Tests/Feature/ApiUsageTest.php @@ -2,12 +2,12 @@ declare(strict_types=1); -use Mod\Api\Models\ApiKey; -use Mod\Api\Models\ApiUsage; -use Mod\Api\Models\ApiUsageDaily; -use Mod\Api\Services\ApiUsageService; -use Mod\Tenant\Models\User; -use Mod\Tenant\Models\Workspace; +use Core\Api\Models\ApiKey; +use Core\Api\Models\ApiUsage; +use Core\Api\Models\ApiUsageDaily; +use Core\Api\Services\ApiUsageService; +use Core\Tenant\Models\User; +use Core\Tenant\Models\Workspace; uses(\Illuminate\Foundation\Testing\RefreshDatabase::class); diff --git a/src/php/src/Api/Tests/Feature/DocumentationStoplightTest.php b/src/php/src/Api/Tests/Feature/DocumentationStoplightTest.php new file mode 100644 index 0000000..ace25a8 --- /dev/null +++ b/src/php/src/Api/Tests/Feature/DocumentationStoplightTest.php @@ -0,0 +1,21 @@ +<?php + +declare(strict_types=1); + +it('renders Stoplight Elements when selected as the default documentation ui', function () { + config(['api-docs.ui.default' => 'stoplight']); + + $response = $this->get('/api/docs'); + + $response->assertOk(); + $response->assertSee('elements-api', false); + $response->assertSee('@stoplight/elements', false); +}); + +it('renders the dedicated Stoplight documentation route', function () { + $response = $this->get('/api/docs/stoplight'); + + $response->assertOk(); + $response->assertSee('elements-api', false); + $response->assertSee('@stoplight/elements', false); +}); diff --git a/src/php/src/Api/Tests/Feature/EntitlementsEndpointTest.php b/src/php/src/Api/Tests/Feature/EntitlementsEndpointTest.php new file mode 100644 index 0000000..ed8b87f --- /dev/null +++ b/src/php/src/Api/Tests/Feature/EntitlementsEndpointTest.php @@ -0,0 +1,54 @@ +<?php + +declare(strict_types=1); + +use Core\Api\Models\ApiKey; +use Core\Api\Services\ApiUsageService; +use Core\Tenant\Models\User; +use Core\Tenant\Models\Workspace; + +uses(\Illuminate\Foundation\Testing\RefreshDatabase::class); + +beforeEach(function () { + $this->user = User::factory()->create(); + $this->workspace = Workspace::factory()->create(); + $this->workspace->users()->attach($this->user->id, [ + 'role' => 'owner', + 'is_default' => true, + ]); + + $result = ApiKey::generate( + $this->workspace->id, + $this->user->id, + 'Entitlements Key', + [ApiKey::SCOPE_READ] + ); + + $this->plainKey = $result['plain_key']; + $this->apiKey = $result['api_key']; +}); + +it('returns entitlement limits and usage for the current workspace', function () { + app(ApiUsageService::class)->record( + apiKeyId: $this->apiKey->id, + workspaceId: $this->workspace->id, + endpoint: '/api/entitlements', + method: 'GET', + statusCode: 200, + responseTimeMs: 42, + ipAddress: '127.0.0.1', + userAgent: 'Pest' + ); + + $response = $this->getJson('/api/entitlements', [ + 'Authorization' => "Bearer {$this->plainKey}", + ]); + + $response->assertOk(); + $response->assertJsonPath('workspace_id', $this->workspace->id); + $response->assertJsonPath('authentication.type', 'api_key'); + $response->assertJsonPath('limits.api_keys.maximum', config('api.keys.max_per_workspace')); + $response->assertJsonPath('limits.api_keys.active', 1); + $response->assertJsonPath('usage.totals.requests', 1); + $response->assertJsonPath('features.mcp', true); +}); diff --git a/src/php/src/Api/Tests/Feature/McpApiControllerTest.php b/src/php/src/Api/Tests/Feature/McpApiControllerTest.php new file mode 100644 index 0000000..88c6a23 --- /dev/null +++ b/src/php/src/Api/Tests/Feature/McpApiControllerTest.php @@ -0,0 +1,43 @@ +<?php + +declare(strict_types=1); + +use Core\Api\Controllers\McpApiController; + +it('includes the requested tool version in the MCP JSON-RPC payload', function () { + $controller = new class extends McpApiController + { + public function payload(string $tool, array $arguments, ?string $version = null): array + { + return $this->buildToolCallRequest($tool, $arguments, $version); + } + }; + + $payload = $controller->payload('search', ['query' => 'status'], '1.2.3'); + + expect($payload['jsonrpc'])->toBe('2.0'); + expect($payload['method'])->toBe('tools/call'); + expect($payload['params'])->toMatchArray([ + 'name' => 'search', + 'arguments' => ['query' => 'status'], + 'version' => '1.2.3', + ]); +}); + +it('omits the version field when one is not requested', function () { + $controller = new class extends McpApiController + { + public function payload(string $tool, array $arguments, ?string $version = null): array + { + return $this->buildToolCallRequest($tool, $arguments, $version); + } + }; + + $payload = $controller->payload('search', ['query' => 'status']); + + expect($payload['params'])->toMatchArray([ + 'name' => 'search', + 'arguments' => ['query' => 'status'], + ]); + expect($payload['params'])->not->toHaveKey('version'); +}); diff --git a/src/php/src/Api/Tests/Feature/McpResourceTest.php b/src/php/src/Api/Tests/Feature/McpResourceTest.php new file mode 100644 index 0000000..9be5108 --- /dev/null +++ b/src/php/src/Api/Tests/Feature/McpResourceTest.php @@ -0,0 +1,102 @@ +<?php + +declare(strict_types=1); + +use Illuminate\Support\Facades\Cache; +use Mod\Api\Models\ApiKey; +use Mod\Tenant\Models\User; +use Mod\Tenant\Models\Workspace; + +uses(\Illuminate\Foundation\Testing\RefreshDatabase::class); + +beforeEach(function () { + Cache::flush(); + + $this->user = User::factory()->create(); + $this->workspace = Workspace::factory()->create(); + $this->workspace->users()->attach($this->user->id, [ + 'role' => 'owner', + 'is_default' => true, + ]); + + $result = ApiKey::generate( + $this->workspace->id, + $this->user->id, + 'MCP Resource Key', + [ApiKey::SCOPE_READ, ApiKey::SCOPE_WRITE] + ); + + $this->plainKey = $result['plain_key']; + + $this->serverId = 'test-resource-server'; + $this->serverDir = resource_path('mcp/servers'); + $this->serverFile = $this->serverDir.'/'.$this->serverId.'.yaml'; + + if (! is_dir($this->serverDir)) { + mkdir($this->serverDir, 0777, true); + } + + file_put_contents($this->serverFile, <<<YAML +id: test-resource-server +name: Test Resource Server +status: available +resources: + - uri: test-resource-server://documents/welcome + path: documents/welcome + name: welcome + content: + message: Hello from the MCP resource bridge + version: 1 +YAML); +}); + +afterEach(function () { + Cache::flush(); + + if (isset($this->serverFile) && is_file($this->serverFile)) { + unlink($this->serverFile); + } + + if (isset($this->serverDir) && is_dir($this->serverDir)) { + @rmdir($this->serverDir); + } + + $mcpDir = dirname($this->serverDir ?? ''); + if (is_dir($mcpDir)) { + @rmdir($mcpDir); + } +}); + +it('reads a resource from the server definition', function () { + $encodedUri = rawurlencode('test-resource-server://documents/welcome'); + + $response = $this->getJson("/api/mcp/resources/{$encodedUri}", [ + 'Authorization' => "Bearer {$this->plainKey}", + ]); + + $response->assertOk(); + $response->assertJson([ + 'uri' => 'test-resource-server://documents/welcome', + 'server' => 'test-resource-server', + 'resource' => 'documents/welcome', + ]); + + expect($response->json('content'))->toBe([ + 'message' => 'Hello from the MCP resource bridge', + 'version' => 1, + ]); +}); + +it('lists resources for a server', function () { + $response = $this->getJson('/api/mcp/servers/test-resource-server/resources', [ + 'Authorization' => "Bearer {$this->plainKey}", + ]); + + $response->assertOk(); + $response->assertJsonPath('server', 'test-resource-server'); + $response->assertJsonPath('count', 1); + $response->assertJsonPath('resources.0.uri', 'test-resource-server://documents/welcome'); + $response->assertJsonPath('resources.0.path', 'documents/welcome'); + $response->assertJsonPath('resources.0.name', 'welcome'); + $response->assertJsonMissingPath('resources.0.content'); +}); diff --git a/src/php/src/Api/Tests/Feature/McpServerDetailTest.php b/src/php/src/Api/Tests/Feature/McpServerDetailTest.php new file mode 100644 index 0000000..7fa3114 --- /dev/null +++ b/src/php/src/Api/Tests/Feature/McpServerDetailTest.php @@ -0,0 +1,116 @@ +<?php + +declare(strict_types=1); + +use Core\Mod\Mcp\Services\ToolVersionService; +use Illuminate\Support\Facades\Cache; +use Mod\Api\Models\ApiKey; +use Mod\Tenant\Models\User; +use Mod\Tenant\Models\Workspace; + +uses(\Illuminate\Foundation\Testing\RefreshDatabase::class); + +beforeEach(function () { + Cache::flush(); + + $this->user = User::factory()->create(); + $this->workspace = Workspace::factory()->create(); + $this->workspace->users()->attach($this->user->id, [ + 'role' => 'owner', + 'is_default' => true, + ]); + + $result = ApiKey::generate( + $this->workspace->id, + $this->user->id, + 'MCP Server Detail Key', + [ApiKey::SCOPE_READ, ApiKey::SCOPE_WRITE] + ); + + $this->plainKey = $result['plain_key']; + + $this->serverId = 'test-detail-server'; + $this->serverDir = resource_path('mcp/servers'); + $this->serverFile = $this->serverDir.'/'.$this->serverId.'.yaml'; + + if (! is_dir($this->serverDir)) { + mkdir($this->serverDir, 0777, true); + } + + file_put_contents($this->serverFile, <<<YAML +id: test-detail-server +name: Test Detail Server +status: available +tools: + - name: search + description: Search records + inputSchema: + type: object + properties: + query: + type: string + required: + - query +resources: + - uri: test-detail-server://documents/welcome + path: documents/welcome + name: welcome + content: + message: Hello from the server detail endpoint + version: 2 +YAML); + + app()->instance(ToolVersionService::class, new class + { + public function getLatestVersion(string $serverId, string $toolName): object + { + return (object) [ + 'version' => '2.1.0', + 'is_deprecated' => false, + 'input_schema' => [ + 'type' => 'object', + 'properties' => [ + 'query' => [ + 'type' => 'string', + ], + ], + 'required' => ['query'], + ], + ]; + } + }); +}); + +afterEach(function () { + Cache::flush(); + + if (isset($this->serverFile) && is_file($this->serverFile)) { + unlink($this->serverFile); + } + + if (isset($this->serverDir) && is_dir($this->serverDir)) { + @rmdir($this->serverDir); + } + + $mcpDir = dirname($this->serverDir ?? ''); + if (is_dir($mcpDir)) { + @rmdir($mcpDir); + } +}); + +it('includes tool versions and resource content on server detail requests when requested', function () { + $response = $this->getJson('/api/mcp/servers/test-detail-server?include_versions=1&include_content=1', [ + 'Authorization' => "Bearer {$this->plainKey}", + ]); + + $response->assertOk(); + $response->assertJsonPath('id', 'test-detail-server'); + $response->assertJsonPath('tools.0.name', 'search'); + $response->assertJsonPath('tools.0.versioning.latest_version', '2.1.0'); + $response->assertJsonPath('tools.0.inputSchema.required.0', 'query'); + $response->assertJsonPath('resources.0.uri', 'test-detail-server://documents/welcome'); + $response->assertJsonPath('resources.0.content.message', 'Hello from the server detail endpoint'); + $response->assertJsonPath('resources.0.content.version', 2); + $response->assertJsonPath('tool_count', 1); + $response->assertJsonPath('resource_count', 1); +}); diff --git a/src/php/src/Api/Tests/Feature/OpenApiDocumentationComprehensiveTest.php b/src/php/src/Api/Tests/Feature/OpenApiDocumentationComprehensiveTest.php index 3669b87..06e7cac 100644 --- a/src/php/src/Api/Tests/Feature/OpenApiDocumentationComprehensiveTest.php +++ b/src/php/src/Api/Tests/Feature/OpenApiDocumentationComprehensiveTest.php @@ -10,6 +10,7 @@ use Core\Api\Documentation\Extension; use Core\Api\Documentation\Extensions\ApiKeyAuthExtension; use Core\Api\Documentation\Extensions\RateLimitExtension; +use Core\Api\Documentation\Extensions\SunsetExtension; use Core\Api\Documentation\Extensions\WorkspaceHeaderExtension; use Core\Api\Documentation\OpenApiBuilder; use Core\Api\RateLimit\RateLimit; @@ -152,6 +153,26 @@ expect($operation['operationId'])->toBe('testScanItemsIndex'); }); + it('makes duplicate operation IDs unique', function () { + RouteFacade::prefix('api') + ->middleware('api') + ->group(function () { + RouteFacade::get('/duplicate-id/dup-one', fn () => response()->json([])); + RouteFacade::get('/duplicate-id/dup_one', fn () => response()->json([])); + }); + + config(['api-docs.routes.include' => ['api/*']]); + + $builder = new OpenApiBuilder; + $spec = $builder->build(); + + $first = $spec['paths']['/api/duplicate-id/dup-one']['get']['operationId']; + $second = $spec['paths']['/api/duplicate-id/dup_one']['get']['operationId']; + + expect($first)->not->toBe($second); + expect($second)->toEndWith('_2'); + }); + it('generates summary from route name', function () { $builder = new OpenApiBuilder; $spec = $builder->build(); @@ -175,6 +196,116 @@ }); }); +// ───────────────────────────────────────────────────────────────────────────── +// Application Endpoint Parameter Docs +// ───────────────────────────────────────────────────────────────────────────── + +describe('Application Endpoint Parameter Docs', function () { + it('documents the SEO report url query parameter', function () { + $builder = new OpenApiBuilder; + $spec = $builder->build(); + + $operation = $spec['paths']['/api/seo/report']['get']; + $urlParam = collect($operation['parameters'] ?? [])->firstWhere('name', 'url'); + + expect($urlParam)->not->toBeNull(); + expect($urlParam['in'])->toBe('query'); + expect($urlParam['required'])->toBeTrue(); + expect($urlParam['schema']['format'])->toBe('uri'); + }); + + it('documents the pixel endpoint as binary for GET and no-content for POST', function () { + $builder = new OpenApiBuilder; + $spec = $builder->build(); + + $getOperation = $spec['paths']['/api/pixel/{pixelKey}']['get']; + $getResponse = $getOperation['responses']['200'] ?? []; + $getContent = $getResponse['content']['image/gif']['schema'] ?? null; + + expect($getContent)->toBe([ + 'type' => 'string', + 'format' => 'binary', + ]); + + $postOperation = $spec['paths']['/api/pixel/{pixelKey}']['post']; + $postResponse = $postOperation['responses']['204'] ?? []; + + expect($postResponse['description'] ?? null)->toBe('Accepted without a response body'); + expect($postResponse)->not->toHaveKey('content'); + }); + + it('documents MCP list query parameters', function () { + $builder = new OpenApiBuilder; + $spec = $builder->build(); + + $toolsOperation = $spec['paths']['/api/mcp/servers/{id}/tools']['get']; + $includeVersions = collect($toolsOperation['parameters'] ?? [])->firstWhere('name', 'include_versions'); + + expect($includeVersions)->not->toBeNull(); + expect($includeVersions['in'])->toBe('query'); + expect($includeVersions['schema']['type'])->toBe('boolean'); + + $resourcesOperation = $spec['paths']['/api/mcp/servers/{id}/resources']['get']; + $includeContent = collect($resourcesOperation['parameters'] ?? [])->firstWhere('name', 'include_content'); + + expect($includeContent)->not->toBeNull(); + expect($includeContent['in'])->toBe('query'); + expect($includeContent['schema']['type'])->toBe('boolean'); + + $serverOperation = $spec['paths']['/api/mcp/servers/{id}']['get']; + $serverIncludeVersions = collect($serverOperation['parameters'] ?? [])->firstWhere('name', 'include_versions'); + $serverIncludeContent = collect($serverOperation['parameters'] ?? [])->firstWhere('name', 'include_content'); + + expect($serverIncludeVersions)->not->toBeNull(); + expect($serverIncludeVersions['in'])->toBe('query'); + expect($serverIncludeVersions['schema']['type'])->toBe('boolean'); + + expect($serverIncludeContent)->not->toBeNull(); + expect($serverIncludeContent['in'])->toBe('query'); + expect($serverIncludeContent['schema']['type'])->toBe('boolean'); + }); + + it('lets explicit path parameter metadata override the generated entry', function () { + RouteFacade::prefix('api') + ->middleware('api') + ->group(function () { + RouteFacade::get('/test-scan/items/{id}/explicit', [TestExplicitPathParameterController::class, 'show']); + }); + + $builder = new OpenApiBuilder; + $spec = $builder->build(); + + $operation = $spec['paths']['/api/test-scan/items/{id}/explicit']['get']; + $parameters = $operation['parameters'] ?? []; + + expect($parameters)->toHaveCount(1); + + $idParam = collect($parameters)->firstWhere('name', 'id'); + + expect($idParam)->not->toBeNull(); + expect($idParam['in'])->toBe('path'); + expect($idParam['required'])->toBeTrue(); + expect($idParam['description'])->toBe('Explicit item identifier'); + }); + + it('documents the MCP tool call request body shape', function () { + $builder = new OpenApiBuilder; + $spec = $builder->build(); + + $operation = $spec['paths']['/api/mcp/tools/call']['post']; + $schema = $operation['requestBody']['content']['application/json']['schema'] ?? null; + + expect($schema)->not->toBeNull(); + expect($schema['type'])->toBe('object'); + expect($schema['properties'])->toHaveKey('server') + ->toHaveKey('tool') + ->toHaveKey('arguments') + ->toHaveKey('version'); + expect($schema['required'])->toBe(['server', 'tool']); + expect($schema['additionalProperties'])->toBeTrue(); + }); +}); + // ───────────────────────────────────────────────────────────────────────────── // ApiParameter Attribute Parsing // ───────────────────────────────────────────────────────────────────────────── @@ -323,9 +454,10 @@ enum: ['draft', 'published', 'archived'] 201 => 'Resource created', 204 => 'No content', 400 => 'Bad request', - 401 => 'Unauthorized', + 401 => 'Unauthorised', 403 => 'Forbidden', 404 => 'Not found', + 410 => 'Gone', 422 => 'Validation error', 429 => 'Too many requests', 500 => 'Internal server error', @@ -347,6 +479,29 @@ enum: ['draft', 'published', 'archived'] expect($response->resource)->toBe(TestJsonResource::class); }); + it('infers resource schema fields from JsonResource payloads', function () { + config(['api-docs.routes.include' => ['api/*']]); + config(['api-docs.routes.exclude' => []]); + + RouteFacade::prefix('api') + ->middleware('api') + ->group(function () { + RouteFacade::get('/test-scan/items/{id}', [TestOpenApiController::class, 'show']); + }); + + $builder = new OpenApiBuilder; + $spec = $builder->build(); + + $schema = $spec['paths']['/api/test-scan/items/{id}']['get']['responses']['200']['content']['application/json']['schema'] ?? null; + + expect($schema)->not->toBeNull(); + expect($schema['type'])->toBe('object'); + expect($schema['properties'])->toHaveKey('id') + ->toHaveKey('name'); + expect($schema['properties']['id']['type'])->toBe('integer'); + expect($schema['properties']['name']['type'])->toBe('string'); + }); + it('supports paginated flag', function () { $response = new ApiResponse( status: 200, @@ -649,7 +804,7 @@ enum: ['draft', 'published', 'archived'] // ───────────────────────────────────────────────────────────────────────────── describe('Error Response Documentation', function () { - it('documents 401 Unauthorized response', function () { + it('documents 401 Unauthorised response', function () { $extension = new ApiKeyAuthExtension; $spec = [ 'info' => [], @@ -711,6 +866,69 @@ enum: ['draft', 'published', 'archived'] }); }); +// ───────────────────────────────────────────────────────────────────────────── +// Sunset Documentation +// ───────────────────────────────────────────────────────────────────────────── + +describe('Sunset Documentation', function () { + it('registers deprecation headers in components', function () { + $extension = new SunsetExtension; + $spec = ['components' => []]; + + $result = $extension->extend($spec, []); + + expect($result['components']['headers'])->toHaveKey('deprecation') + ->toHaveKey('sunset') + ->toHaveKey('link') + ->toHaveKey('xapiwarn'); + }); + + it('marks sunset routes as deprecated and documents their response headers', function () { + RouteFacade::prefix('api') + ->middleware(['api', 'api.sunset:2025-06-01,/api/v2/legacy']) + ->group(function () { + RouteFacade::get('/sunset-test/legacy', fn () => response()->json(['ok' => true])) + ->name('sunset-test.legacy'); + }); + + config(['api-docs.routes.include' => ['api/*']]); + + $builder = new OpenApiBuilder; + $spec = $builder->build(); + + $operation = $spec['paths']['/api/sunset-test/legacy']['get']; + + expect($operation['deprecated'])->toBeTrue(); + expect($operation['responses']['200']['headers'])->toHaveKey('Deprecation') + ->toHaveKey('Sunset') + ->toHaveKey('X-API-Warn') + ->toHaveKey('Link'); + }); + + it('only documents the sunset headers that the middleware will emit', function () { + RouteFacade::prefix('api') + ->middleware(['api', 'api.sunset']) + ->group(function () { + RouteFacade::get('/sunset-test/plain', fn () => response()->json(['ok' => true])) + ->name('sunset-test.plain'); + }); + + config(['api-docs.routes.include' => ['api/*']]); + + $builder = new OpenApiBuilder; + $spec = $builder->build(); + + $operation = $spec['paths']['/api/sunset-test/plain']['get']; + $headers = $operation['responses']['200']['headers']; + + expect($operation['deprecated'])->toBeTrue(); + expect($headers)->toHaveKey('Deprecation') + ->toHaveKey('X-API-Warn'); + expect($headers)->not->toHaveKey('Sunset'); + expect($headers)->not->toHaveKey('Link'); + }); +}); + // ───────────────────────────────────────────────────────────────────────────── // Authentication Documentation // ───────────────────────────────────────────────────────────────────────────── @@ -1027,6 +1245,16 @@ public function publicMethod(): void {} public function hiddenMethod(): void {} } +/** + * Test controller with an explicit path parameter override. + */ +class TestExplicitPathParameterController +{ + #[ApiParameter('id', 'path', 'string', 'Explicit item identifier')] + #[ApiResponse(200, TestJsonResource::class, 'Item details')] + public function show(string $id): void {} +} + /** * Test tagged controller. */ diff --git a/src/php/src/Api/Tests/Feature/OpenApiDocumentationTest.php b/src/php/src/Api/Tests/Feature/OpenApiDocumentationTest.php index 69fc496..843e00f 100644 --- a/src/php/src/Api/Tests/Feature/OpenApiDocumentationTest.php +++ b/src/php/src/Api/Tests/Feature/OpenApiDocumentationTest.php @@ -57,6 +57,10 @@ public function test_api_response_generates_description_from_status(): void $response = new ApiResponse(404); $this->assertEquals('Not found', $response->getDescription()); + + $goneResponse = new ApiResponse(410); + + $this->assertEquals('Gone', $goneResponse->getDescription()); } public function test_api_security_attribute(): void diff --git a/src/php/src/Api/Tests/Feature/OpenApiVersionHeadersTest.php b/src/php/src/Api/Tests/Feature/OpenApiVersionHeadersTest.php new file mode 100644 index 0000000..40514e2 --- /dev/null +++ b/src/php/src/Api/Tests/Feature/OpenApiVersionHeadersTest.php @@ -0,0 +1,43 @@ +<?php + +declare(strict_types=1); + +use Core\Api\Documentation\OpenApiBuilder; +use Illuminate\Support\Facades\Config; +use Illuminate\Support\Facades\Route as RouteFacade; + +beforeEach(function () { + Config::set('api.headers.include_version', true); + Config::set('api.headers.include_deprecation', true); + Config::set('api.versioning.deprecated', [1]); + Config::set('api.versioning.sunset', [ + 1 => '2025-06-01', + ]); + Config::set('api-docs.routes.include', ['api/*']); + Config::set('api-docs.routes.exclude', []); +}); + +it('documents version headers and version-driven deprecation on versioned routes', function () { + RouteFacade::prefix('api/v1') + ->middleware(['api', 'api.version:1']) + ->group(function () { + RouteFacade::get('/legacy-status', fn () => response()->json(['ok' => true])); + }); + + $spec = (new OpenApiBuilder)->build(); + + expect($spec['components']['headers']['xapiversion'] ?? null)->not->toBeNull(); + + $operation = $spec['paths']['/api/v1/legacy-status']['get']; + + expect($operation['deprecated'] ?? null)->toBeTrue(); + + foreach (['200', '400', '500'] as $status) { + $headers = $operation['responses'][$status]['headers'] ?? []; + + expect($headers)->toHaveKey('X-API-Version'); + expect($headers)->toHaveKey('Deprecation'); + expect($headers)->toHaveKey('Sunset'); + expect($headers)->toHaveKey('X-API-Warn'); + } +}); diff --git a/src/php/src/Api/Tests/Feature/PixelEndpointTest.php b/src/php/src/Api/Tests/Feature/PixelEndpointTest.php new file mode 100644 index 0000000..1e58d75 --- /dev/null +++ b/src/php/src/Api/Tests/Feature/PixelEndpointTest.php @@ -0,0 +1,48 @@ +<?php + +declare(strict_types=1); + +use Illuminate\Support\Facades\Cache; + +beforeEach(function () { + Cache::flush(); +}); + +afterEach(function () { + Cache::flush(); +}); + +it('returns a transparent gif for get requests', function () { + $response = $this->get('/api/pixel/abc12345', [ + 'Origin' => 'https://example.com', + ]); + + $response->assertOk(); + $response->assertHeader('Content-Type', 'image/gif'); + $response->assertHeader('Access-Control-Allow-Origin', 'https://example.com'); + $response->assertHeader('X-RateLimit-Limit', '10000'); + $response->assertHeader('X-RateLimit-Remaining', '9999'); + + expect($response->getContent())->toBe(base64_decode('R0lGODlhAQABAPAAAP///wAAACH5BAAAAAAALAAAAAABAAEAAAICRAEAOw==')); +}); + +it('accepts post tracking requests without a body', function () { + $response = $this->post('/api/pixel/abc12345', [], [ + 'Origin' => 'https://example.com', + ]); + + $response->assertNoContent(); + $response->assertHeader('Access-Control-Allow-Origin', 'https://example.com'); + $response->assertHeader('X-RateLimit-Limit', '10000'); + $response->assertHeader('X-RateLimit-Remaining', '9999'); +}); + +it('handles preflight requests for public pixel tracking', function () { + $response = $this->call('OPTIONS', '/api/pixel/abc12345', [], [], [], [ + 'HTTP_ORIGIN' => 'https://example.com', + ]); + + $response->assertNoContent(); + $response->assertHeader('Access-Control-Allow-Origin', 'https://example.com'); + $response->assertHeader('Access-Control-Allow-Methods', 'GET, POST, OPTIONS'); +}); diff --git a/src/php/src/Api/Tests/Feature/SeoReportEndpointTest.php b/src/php/src/Api/Tests/Feature/SeoReportEndpointTest.php new file mode 100644 index 0000000..4bcde61 --- /dev/null +++ b/src/php/src/Api/Tests/Feature/SeoReportEndpointTest.php @@ -0,0 +1,70 @@ +<?php + +declare(strict_types=1); + +use Core\Api\Models\ApiKey; +use Core\Tenant\Models\User; +use Core\Tenant\Models\Workspace; +use Illuminate\Support\Facades\Http; + +uses(\Illuminate\Foundation\Testing\RefreshDatabase::class); + +beforeEach(function () { + $this->user = User::factory()->create(); + $this->workspace = Workspace::factory()->create(); + $this->workspace->users()->attach($this->user->id, [ + 'role' => 'owner', + 'is_default' => true, + ]); + + $result = ApiKey::generate( + $this->workspace->id, + $this->user->id, + 'SEO Key', + [ApiKey::SCOPE_READ] + ); + + $this->plainKey = $result['plain_key']; +}); + +it('returns a technical SEO report for a URL', function () { + Http::fake([ + 'https://example.com*' => Http::response(<<<'HTML' +<!doctype html> +<html lang="en"> +<head> + <meta charset="utf-8"> + <title>Example Product Landing Page + + + + + + + + + + +

Example Product Landing Page

+

Key Features

+ + +HTML, 200, [ + 'Content-Type' => 'text/html; charset=utf-8', + ]), + ]); + + $response = $this->getJson('/api/seo/report?url=https://example.com', [ + 'Authorization' => "Bearer {$this->plainKey}", + ]); + + $response->assertOk(); + $response->assertJsonPath('data.url', 'https://example.com'); + $response->assertJsonPath('data.status_code', 200); + $response->assertJsonPath('data.summary.title', 'Example Product Landing Page'); + $response->assertJsonPath('data.summary.description', 'A concise example description for the landing page.'); + $response->assertJsonPath('data.headings.h1', 1); + $response->assertJsonPath('data.open_graph.site_name', 'Example'); + $response->assertJsonPath('data.score', 100); + $response->assertJsonPath('data.issues', []); +}); diff --git a/src/php/src/Api/Tests/Feature/WebhookSecretRoutesTest.php b/src/php/src/Api/Tests/Feature/WebhookSecretRoutesTest.php new file mode 100644 index 0000000..79a9066 --- /dev/null +++ b/src/php/src/Api/Tests/Feature/WebhookSecretRoutesTest.php @@ -0,0 +1,28 @@ +getByName('api.webhooks.social.rotate-secret'); + $socialStatus = Route::getRoutes()->getByName('api.webhooks.social.status'); + $contentRotate = Route::getRoutes()->getByName('api.webhooks.content.rotate-secret'); + $contentStatus = Route::getRoutes()->getByName('api.webhooks.content.status'); + + expect($socialRotate)->not->toBeNull(); + expect($socialRotate->uri())->toBe('api/webhooks/social/{uuid}/secret/rotate'); + expect($socialRotate->methods())->toContain('POST'); + + expect($socialStatus)->not->toBeNull(); + expect($socialStatus->uri())->toBe('api/webhooks/social/{uuid}/secret'); + expect($socialStatus->methods())->toContain('GET'); + + expect($contentRotate)->not->toBeNull(); + expect($contentRotate->uri())->toBe('api/webhooks/content/{uuid}/secret/rotate'); + expect($contentRotate->methods())->toContain('POST'); + + expect($contentStatus)->not->toBeNull(); + expect($contentStatus->uri())->toBe('api/webhooks/content/{uuid}/secret'); + expect($contentStatus->methods())->toContain('GET'); +}); diff --git a/src/php/src/Api/config.php b/src/php/src/Api/config.php index 701ee76..d2a835d 100644 --- a/src/php/src/Api/config.php +++ b/src/php/src/Api/config.php @@ -1,5 +1,7 @@ [ + // HTTP timeout when fetching a page for analysis + 'timeout' => env('API_SEO_TIMEOUT', 10), + ], + /* |-------------------------------------------------------------------------- | Pagination diff --git a/src/php/src/Front/Api/ApiVersionService.php b/src/php/src/Front/Api/ApiVersionService.php index 5b889d2..2e23593 100644 --- a/src/php/src/Front/Api/ApiVersionService.php +++ b/src/php/src/Front/Api/ApiVersionService.php @@ -54,6 +54,69 @@ */ class ApiVersionService { + /** + * Normalise a list of API versions to unique positive integers. + * + * @param array $versions + * @return array + */ + protected function normaliseVersions(array $versions): array + { + $normalised = []; + + foreach ($versions as $version) { + if (! is_numeric($version)) { + continue; + } + + $version = (int) $version; + if ($version <= 0) { + continue; + } + + $normalised[] = $version; + } + + return array_values(array_unique($normalised)); + } + + /** + * Normalise sunset dates to an integer-keyed map. + * + * @param array $sunsets + * @return array + */ + protected function normaliseSunsetDates(array $sunsets): array + { + $normalised = []; + + foreach ($sunsets as $version => $date) { + if (! is_numeric($version)) { + continue; + } + + $version = (int) $version; + if ($version <= 0) { + continue; + } + + if ($date === null) { + continue; + } + + $date = trim((string) $date); + if ($date === '') { + continue; + } + + $normalised[$version] = $date; + } + + ksort($normalised); + + return $normalised; + } + /** * Get the current API version from the request. * @@ -116,7 +179,7 @@ public function isAtLeast(int $version, ?Request $request = null): bool public function isDeprecated(?Request $request = null): bool { $current = $this->current($request); - $deprecated = config('api.versioning.deprecated', []); + $deprecated = $this->deprecatedVersions(); return $current !== null && in_array($current, $deprecated, true); } @@ -144,7 +207,7 @@ public function latestVersion(): int */ public function supportedVersions(): array { - return config('api.versioning.supported', [1]); + return $this->normaliseVersions((array) config('api.versioning.supported', [1])); } /** @@ -154,7 +217,7 @@ public function supportedVersions(): array */ public function deprecatedVersions(): array { - return config('api.versioning.deprecated', []); + return $this->normaliseVersions((array) config('api.versioning.deprecated', [])); } /** @@ -164,7 +227,7 @@ public function deprecatedVersions(): array */ public function sunsetDates(): array { - return config('api.versioning.sunset', []); + return $this->normaliseSunsetDates((array) config('api.versioning.sunset', [])); } /** diff --git a/src/php/src/Front/Api/Middleware/ApiSunset.php b/src/php/src/Front/Api/Middleware/ApiSunset.php index c853f9a..8ec0829 100644 --- a/src/php/src/Front/Api/Middleware/ApiSunset.php +++ b/src/php/src/Front/Api/Middleware/ApiSunset.php @@ -12,101 +12,77 @@ namespace Core\Front\Api\Middleware; use Closure; +use DateTimeImmutable; +use DateTimeInterface; +use DateTimeZone; use Illuminate\Http\Request; use Symfony\Component\HttpFoundation\Response; /** * API Sunset Middleware. * - * Adds the HTTP Sunset header to responses to indicate when an endpoint - * will be deprecated or removed. - * - * The Sunset header is defined in RFC 8594 and indicates that a resource - * will become unresponsive at the specified date. - * - * ## Usage - * - * Apply to routes that will be sunset: - * - * ```php - * Route::middleware('api.sunset:2025-06-01')->group(function () { - * Route::get('/legacy-endpoint', LegacyController::class); - * }); - * ``` - * - * Or with a replacement link: - * - * ```php - * Route::middleware('api.sunset:2025-06-01,/api/v2/new-endpoint')->group(function () { - * Route::get('/old-endpoint', OldController::class); - * }); - * ``` - * - * ## Response Headers - * - * The middleware adds these headers: - * - Sunset: - * - Deprecation: true - * - Link: ; rel="successor-version" (if replacement provided) - * - * @see https://datatracker.ietf.org/doc/html/rfc8594 RFC 8594: The "Sunset" HTTP Header Field + * Adds deprecation headers to a route and optionally advertises a sunset + * date and successor endpoint. Existing header values are preserved so + * downstream middleware and handlers can keep their own warning metadata. */ class ApiSunset { /** * Handle an incoming request. * - * @param string $sunsetDate The sunset date (YYYY-MM-DD or RFC7231 format) - * @param string|null $replacement Optional replacement endpoint URL + * @param string $sunsetDate The sunset date (YYYY-MM-DD or RFC7231 format), or empty for deprecation-only + * @param string|null $replacement Optional successor endpoint URL */ - public function handle(Request $request, Closure $next, string $sunsetDate, ?string $replacement = null): Response + public function handle(Request $request, Closure $next, string $sunsetDate = '', ?string $replacement = null): Response { /** @var Response $response */ $response = $next($request); - // Convert date to RFC7231 format if needed - $formattedDate = $this->formatSunsetDate($sunsetDate); + if (! (bool) config('api.headers.include_deprecation', true)) { + return $response; + } - // Add Sunset header - $response->headers->set('Sunset', $formattedDate); + $response->headers->set('Deprecation', 'true', false); - // Add Deprecation header - $response->headers->set('Deprecation', 'true'); + if ($sunsetDate !== '') { + $response->headers->set('Sunset', $this->formatSunsetDate($sunsetDate), false); + } - // Add warning header - $version = $request->attributes->get('api_version', 'unknown'); - $response->headers->set( - 'X-API-Warn', - "This endpoint is deprecated and will be removed on {$sunsetDate}." - ); + if ($replacement !== null && $replacement !== '') { + $response->headers->set('Link', sprintf('<%s>; rel="successor-version"', $replacement), false); + } - // Add Link header for replacement if provided - if ($replacement !== null) { - $response->headers->set('Link', "<{$replacement}>; rel=\"successor-version\""); + $warning = 'This endpoint is deprecated.'; + if ($sunsetDate !== '') { + $warning = "This endpoint is deprecated and will be removed on {$sunsetDate}."; } + $response->headers->set('X-API-Warn', $warning, false); + return $response; } /** - * Format the sunset date to RFC7231 format. - * - * Accepts dates in YYYY-MM-DD format or already-formatted RFC7231 dates. + * Format the sunset date to RFC7231 format when possible. */ - protected function formatSunsetDate(string $date): string + protected function formatSunsetDate(string $sunsetDate): string { - // Check if already in RFC7231 format (contains comma, day name) - if (str_contains($date, ',')) { - return $date; + $sunsetDate = trim($sunsetDate); + if ($sunsetDate === '') { + return $sunsetDate; + } + + // Already RFC7231-style dates contain a comma, so preserve them. + if (str_contains($sunsetDate, ',')) { + return $sunsetDate; } try { - return (new \DateTimeImmutable($date)) - ->setTimezone(new \DateTimeZone('GMT')) - ->format(\DateTimeInterface::RFC7231); - } catch (\Exception) { - // If parsing fails, return as-is - return $date; + return (new DateTimeImmutable($sunsetDate)) + ->setTimezone(new DateTimeZone('GMT')) + ->format(DateTimeInterface::RFC7231); + } catch (\Throwable) { + return $sunsetDate; } } } diff --git a/src/php/src/Front/Api/Middleware/ApiVersion.php b/src/php/src/Front/Api/Middleware/ApiVersion.php index 52c659b..5bd073a 100644 --- a/src/php/src/Front/Api/Middleware/ApiVersion.php +++ b/src/php/src/Front/Api/Middleware/ApiVersion.php @@ -105,11 +105,16 @@ public function handle(Request $request, Closure $next, ?int $requiredVersion = /** @var Response $response */ $response = $next($request); - // Add version header to response - $response->headers->set('X-API-Version', (string) $version); + $includeVersionHeader = (bool) config('api.headers.include_version', true); + $includeDeprecationHeaders = (bool) config('api.headers.include_deprecation', true); + + // Add version header to response when enabled + if ($includeVersionHeader) { + $response->headers->set('X-API-Version', (string) $version); + } // Add deprecation headers if applicable - if (in_array($version, $deprecated, true)) { + if ($includeDeprecationHeaders && in_array($version, $deprecated, true)) { $response->headers->set('Deprecation', 'true'); $response->headers->set('X-API-Warn', "API version {$version} is deprecated. Please upgrade to v{$current}."); @@ -183,7 +188,9 @@ protected function versionFromHeader(Request $request): ?int return null; } - // Strip 'v' prefix if present + // Strip 'v' prefix and any optional parameters if present. + $header = trim($header); + $header = explode(';', $header, 2)[0]; $version = ltrim($header, 'vV'); if (is_numeric($version)) { @@ -202,9 +209,14 @@ protected function versionFromAcceptHeader(Request $request): ?int { $accept = $request->header('Accept', ''); - // Match vendor media type: application/vnd.{name}.v{n}+json - if (preg_match('#application/vnd\.[^.]+\.v(\d+)\+#', $accept, $matches)) { - return (int) $matches[1]; + foreach (preg_split('/\s*,\s*/', $accept, -1, PREG_SPLIT_NO_EMPTY) ?: [] as $mediaType) { + // Strip media-type parameters before matching the vendor suffix. + $mediaType = explode(';', trim($mediaType), 2)[0]; + + // Match vendor media type: application/vnd.{name}.v{n}+json + if (preg_match('#^application/vnd\.[^.]+\.v(\d+)\+#i', $mediaType, $matches)) { + return (int) $matches[1]; + } } return null; @@ -217,6 +229,12 @@ protected function versionFromAcceptHeader(Request $request): ?int */ protected function unsupportedVersion(int $requested, array $supported, int $current): Response { + $headers = []; + + if ((bool) config('api.headers.include_version', true)) { + $headers['X-API-Version'] = (string) $current; + } + return response()->json([ 'error' => 'unsupported_api_version', 'message' => "API version {$requested} is not supported.", @@ -224,9 +242,7 @@ protected function unsupportedVersion(int $requested, array $supported, int $cur 'supported_versions' => $supported, 'current_version' => $current, 'hint' => 'Use Accept-Version header or URL prefix (e.g., /api/v1/) to specify version.', - ], 400, [ - 'X-API-Version' => (string) $current, - ]); + ], 400, $headers); } /** @@ -234,13 +250,17 @@ protected function unsupportedVersion(int $requested, array $supported, int $cur */ protected function versionTooLow(int $requested, int $required): Response { + $headers = []; + + if ((bool) config('api.headers.include_version', true)) { + $headers['X-API-Version'] = (string) $requested; + } + return response()->json([ 'error' => 'api_version_too_low', 'message' => "This endpoint requires API version {$required} or higher.", 'requested_version' => $requested, 'minimum_version' => $required, - ], 400, [ - 'X-API-Version' => (string) $requested, - ]); + ], 400, $headers); } } diff --git a/src/php/src/Front/Api/README.md b/src/php/src/Front/Api/README.md index 45689c6..98aec29 100644 --- a/src/php/src/Front/Api/README.md +++ b/src/php/src/Front/Api/README.md @@ -145,7 +145,7 @@ VersionedRoutes::versions([1, 2], function () { // Deprecated version with sunset VersionedRoutes::v1() - ->deprecated('2025-06-01') + ->deprecated('2025-06-01', '/api/v2/new-endpoint') ->routes(function () { Route::get('/legacy', LegacyController::class); }); diff --git a/src/php/src/Front/Api/VersionedRoutes.php b/src/php/src/Front/Api/VersionedRoutes.php index 5ebe22f..4239435 100644 --- a/src/php/src/Front/Api/VersionedRoutes.php +++ b/src/php/src/Front/Api/VersionedRoutes.php @@ -72,7 +72,7 @@ * * ```php * VersionedRoutes::v1() - * ->deprecated('2025-06-01') + * ->deprecated('2025-06-01', '/api/v2/new-endpoint') * ->routes(function () { * Route::get('/legacy', ...); * }); @@ -86,6 +86,11 @@ class VersionedRoutes protected ?string $sunsetDate = null; + /** + * @var string|null + */ + protected ?string $replacement = null; + protected bool $isDeprecated = false; /** @@ -178,11 +183,13 @@ public function withPrefix(): static * Mark this version as deprecated. * * @param string|null $sunsetDate Optional sunset date (YYYY-MM-DD or RFC7231 format) + * @param string|null $replacement Optional replacement endpoint URL */ - public function deprecated(?string $sunsetDate = null): static + public function deprecated(?string $sunsetDate = null, ?string $replacement = null): static { $this->isDeprecated = true; $this->sunsetDate = $sunsetDate; + $this->replacement = $replacement; return $this; } @@ -239,8 +246,18 @@ protected function buildMiddleware(): array { $middleware = ["api.version:{$this->version}"]; - if ($this->isDeprecated && $this->sunsetDate) { - $middleware[] = "api.sunset:{$this->sunsetDate}"; + if ($this->isDeprecated) { + if ($this->sunsetDate !== null && $this->sunsetDate !== '') { + if ($this->replacement !== null && $this->replacement !== '') { + $middleware[] = "api.sunset:{$this->sunsetDate},{$this->replacement}"; + } else { + $middleware[] = "api.sunset:{$this->sunsetDate}"; + } + } elseif ($this->replacement !== null && $this->replacement !== '') { + $middleware[] = "api.sunset:,$this->replacement"; + } else { + $middleware[] = 'api.sunset'; + } } return array_merge($middleware, $this->middleware); diff --git a/src/php/src/Website/Api/Controllers/DocsController.php b/src/php/src/Website/Api/Controllers/DocsController.php index b1140da..6715927 100644 --- a/src/php/src/Website/Api/Controllers/DocsController.php +++ b/src/php/src/Website/Api/Controllers/DocsController.php @@ -65,6 +65,11 @@ public function redoc(): View return view('api::redoc'); } + public function stoplight(): View + { + return view('api::stoplight'); + } + public function openapi(OpenApiGenerator $generator): JsonResponse { return response()->json($generator->generate()); diff --git a/src/php/src/Website/Api/Routes/web.php b/src/php/src/Website/Api/Routes/web.php index a545b4a..95713e8 100644 --- a/src/php/src/Website/Api/Routes/web.php +++ b/src/php/src/Website/Api/Routes/web.php @@ -28,6 +28,9 @@ // ReDoc (three-panel API reference) Route::get('/redoc', [DocsController::class, 'redoc'])->name('api.redoc'); +// Stoplight Elements API reference +Route::get('/stoplight', [DocsController::class, 'stoplight'])->name('api.stoplight'); + // OpenAPI spec (rate limited - expensive to generate) Route::get('/openapi.json', [DocsController::class, 'openapi']) ->middleware('throttle:60,1') diff --git a/src/php/src/Website/Api/View/Blade/layouts/docs.blade.php b/src/php/src/Website/Api/View/Blade/layouts/docs.blade.php index 5500522..7fa53ab 100644 --- a/src/php/src/Website/Api/View/Blade/layouts/docs.blade.php +++ b/src/php/src/Website/Api/View/Blade/layouts/docs.blade.php @@ -81,7 +81,7 @@ class="w-full sm:w-80 text-sm bg-white text-zinc-400 inline-flex items-center ju
diff --git a/src/php/src/Website/Api/View/Blade/stoplight.blade.php b/src/php/src/Website/Api/View/Blade/stoplight.blade.php new file mode 100644 index 0000000..ed4b32c --- /dev/null +++ b/src/php/src/Website/Api/View/Blade/stoplight.blade.php @@ -0,0 +1,23 @@ +@extends('layouts::docs') + +@section('title', 'Stoplight') +@section('description', 'Stoplight Elements API reference for the Core API.') + +@section('content') +
+ +
+@endsection + +@push('head') + +@endpush + +@push('scripts') + +@endpush diff --git a/src/php/tests/Feature/ApiSunsetTest.php b/src/php/tests/Feature/ApiSunsetTest.php new file mode 100644 index 0000000..50255db --- /dev/null +++ b/src/php/tests/Feature/ApiSunsetTest.php @@ -0,0 +1,88 @@ +handle($request, fn () => new Response('OK')); + + expect($response->headers->get('Deprecation'))->toBe('true'); + expect($response->headers->has('Sunset'))->toBeFalse(); + expect($response->headers->has('Link'))->toBeFalse(); + expect($response->headers->get('X-API-Warn'))->toBe('This endpoint is deprecated.'); +}); + +it('adds a replacement link without a sunset date', function () { + Config::set('api.headers.include_deprecation', true); + + $middleware = new ApiSunset(); + $request = Request::create('/old-endpoint', 'GET'); + + $response = $middleware->handle($request, fn () => new Response('OK'), '', '/api/v4/users'); + + expect($response->headers->get('Deprecation'))->toBe('true'); + expect($response->headers->has('Sunset'))->toBeFalse(); + expect($response->headers->get('Link'))->toBe('; rel="successor-version"'); + expect($response->headers->get('X-API-Warn'))->toBe('This endpoint is deprecated.'); +}); + +it('preserves existing deprecation headers while appending sunset metadata', function () { + Config::set('api.headers.include_deprecation', true); + + $middleware = new ApiSunset(); + $request = Request::create('/legacy-endpoint', 'GET'); + + $response = $middleware->handle($request, function () { + $response = new Response('OK'); + $response->headers->set('Deprecation', 'false'); + $response->headers->set('Sunset', 'Wed, 01 Jan 2025 00:00:00 GMT'); + $response->headers->set('Link', '; rel="help"'); + $response->headers->set('X-API-Warn', 'Existing warning'); + + return $response; + }, '2025-06-01', '/api/v2/users'); + + expect($response->headers->all('Deprecation'))->toHaveCount(2); + expect($response->headers->all('Sunset'))->toHaveCount(2); + expect($response->headers->all('Link'))->toHaveCount(2); + expect($response->headers->all('X-API-Warn'))->toHaveCount(2); + expect($response->headers->all('Link'))->toContain('; rel="help"'); + expect($response->headers->all('Link'))->toContain('; rel="successor-version"'); +}); + +it('formats the sunset date and keeps the replacement link', function () { + Config::set('api.headers.include_deprecation', true); + + $middleware = new ApiSunset(); + $request = Request::create('/legacy-endpoint', 'GET'); + + $response = $middleware->handle($request, fn () => new Response('OK'), '2025-06-01', '/api/v2/users'); + + expect($response->headers->get('Deprecation'))->toBe('true'); + expect($response->headers->get('Sunset'))->toBe('Sun, 01 Jun 2025 00:00:00 GMT'); + expect($response->headers->get('Link'))->toBe('; rel="successor-version"'); + expect($response->headers->get('X-API-Warn'))->toBe('This endpoint is deprecated and will be removed on 2025-06-01.'); +}); + +it('skips deprecation headers when they are disabled in configuration', function () { + Config::set('api.headers.include_deprecation', false); + + $middleware = new ApiSunset(); + $request = Request::create('/legacy-endpoint', 'GET'); + + $response = $middleware->handle($request, fn () => new Response('OK'), '2025-06-01', '/api/v2/users'); + + expect($response->headers->has('Deprecation'))->toBeFalse(); + expect($response->headers->has('Sunset'))->toBeFalse(); + expect($response->headers->has('Link'))->toBeFalse(); + expect($response->headers->has('X-API-Warn'))->toBeFalse(); +}); diff --git a/src/php/tests/Feature/ApiVersionHeadersTest.php b/src/php/tests/Feature/ApiVersionHeadersTest.php new file mode 100644 index 0000000..cd5a645 --- /dev/null +++ b/src/php/tests/Feature/ApiVersionHeadersTest.php @@ -0,0 +1,45 @@ + '2025-06-01', + ]); + Config::set('api.headers.include_version', true); + Config::set('api.headers.include_deprecation', true); +}); + +it('skips the api version header when it is disabled in configuration', function () { + Config::set('api.headers.include_version', false); + + $middleware = new ApiVersion(); + $request = Request::create('/api/users', 'GET'); + + $response = $middleware->handle($request, fn () => new Response('OK')); + + expect($response->headers->has('X-API-Version'))->toBeFalse(); +}); + +it('skips deprecation headers when they are disabled in configuration', function () { + Config::set('api.headers.include_deprecation', false); + + $middleware = new ApiVersion(); + $request = Request::create('/api/v1/users', 'GET'); + + $response = $middleware->handle($request, fn () => new Response('OK')); + + expect($response->headers->get('X-API-Version'))->toBe('1'); + expect($response->headers->has('Deprecation'))->toBeFalse(); + expect($response->headers->has('Sunset'))->toBeFalse(); + expect($response->headers->has('X-API-Warn'))->toBeFalse(); +}); diff --git a/src/php/tests/Feature/ApiVersionParsingTest.php b/src/php/tests/Feature/ApiVersionParsingTest.php new file mode 100644 index 0000000..59b5656 --- /dev/null +++ b/src/php/tests/Feature/ApiVersionParsingTest.php @@ -0,0 +1,45 @@ +headers->set('Accept-Version', 'v2; q=1.0'); + + $response = $middleware->handle($request, fn () => new Response('OK')); + + expect($response->headers->get('X-API-Version'))->toBe('2'); + expect($request->attributes->get('api_version'))->toBe(2); + expect($request->attributes->get('api_version_string'))->toBe('v2'); +}); + +it('resolves the api version from a vendor accept header inside a list', function () { + $middleware = new ApiVersion(); + $request = Request::create('/api/users', 'GET'); + $request->headers->set( + 'Accept', + 'text/html;q=0.8, application/json, application/vnd.hosthub.v2+json; charset=utf-8' + ); + + $response = $middleware->handle($request, fn () => new Response('OK')); + + expect($response->headers->get('X-API-Version'))->toBe('2'); + expect($request->attributes->get('api_version'))->toBe(2); + expect($request->attributes->get('api_version_string'))->toBe('v2'); +}); diff --git a/src/php/tests/Feature/ApiVersionServiceTest.php b/src/php/tests/Feature/ApiVersionServiceTest.php new file mode 100644 index 0000000..b9b299d --- /dev/null +++ b/src/php/tests/Feature/ApiVersionServiceTest.php @@ -0,0 +1,40 @@ + '2025-06-01', + '02' => '2025-12-31', + 'ignored' => '2026-01-01', + 0 => '2024-01-01', + -1 => '2024-06-01', + 3 => '', + ]); + + $versions = new ApiVersionService(); + + expect($versions->supportedVersions())->toBe([1, 2]); + expect($versions->deprecatedVersions())->toBe([1, 3]); + expect($versions->sunsetDates())->toBe([ + 1 => '2025-06-01', + 2 => '2025-12-31', + ]); + expect($versions->isSupported(1))->toBeTrue(); + expect($versions->isSupported(2))->toBeTrue(); + expect($versions->isSupported(3))->toBeFalse(); + expect($versions->isDeprecated())->toBeFalse(); +}); diff --git a/src/php/tests/Feature/VersionedRoutesTest.php b/src/php/tests/Feature/VersionedRoutesTest.php new file mode 100644 index 0000000..1aef855 --- /dev/null +++ b/src/php/tests/Feature/VersionedRoutesTest.php @@ -0,0 +1,62 @@ +buildRouteAttributes(); + } + }; + + $attributes = $routes->deprecated('2025-06-01', '/api/v3/users')->attributes(); + + expect($attributes)->toHaveKey('middleware'); + expect($attributes['middleware'])->toContain('api.version:2'); + expect($attributes['middleware'])->toContain('api.sunset:2025-06-01,/api/v3/users'); +}); + +it('preserves the existing deprecated signature without a replacement url', function () { + $routes = new class (1) extends VersionedRoutes { + public function attributes(): array + { + return $this->buildRouteAttributes(); + } + }; + + $attributes = $routes->deprecated('2025-06-01')->attributes(); + + expect($attributes['middleware'])->toContain('api.sunset:2025-06-01'); + expect($attributes['middleware'])->not->toContain('api.sunset:2025-06-01,/api/v3/users'); +}); + +it('keeps deprecated routes active without a sunset date', function () { + $routes = new class (3) extends VersionedRoutes { + public function attributes(): array + { + return $this->buildRouteAttributes(); + } + }; + + $attributes = $routes->deprecated()->attributes(); + + expect($attributes['middleware'])->toContain('api.version:3'); + expect($attributes['middleware'])->toContain('api.sunset'); +}); + +it('passes a replacement url through deprecated versioned routes without a sunset date', function () { + $routes = new class (4) extends VersionedRoutes { + public function attributes(): array + { + return $this->buildRouteAttributes(); + } + }; + + $attributes = $routes->deprecated(null, '/api/v4/users')->attributes(); + + expect($attributes['middleware'])->toContain('api.version:4'); + expect($attributes['middleware'])->toContain('api.sunset:,/api/v4/users'); +}); diff --git a/sse.go b/sse.go index 9adf7ee..8430f34 100644 --- a/sse.go +++ b/sse.go @@ -6,25 +6,37 @@ import ( "encoding/json" "fmt" "net/http" + "strings" "sync" "github.com/gin-gonic/gin" ) +// defaultSSEPath is the URL path where the SSE endpoint is mounted. +const defaultSSEPath = "/events" + // SSEBroker manages Server-Sent Events connections and broadcasts events // to subscribed clients. Clients connect via a GET endpoint and receive // a streaming text/event-stream response. Each client may optionally // subscribe to a specific channel via the ?channel= query parameter. +// +// Example: +// +// broker := api.NewSSEBroker() +// engine.GET("/events", broker.Handler()) type SSEBroker struct { mu sync.RWMutex + wg sync.WaitGroup clients map[*sseClient]struct{} } // sseClient represents a single connected SSE consumer. type sseClient struct { - channel string - events chan sseEvent - done chan struct{} + channel string + events chan sseEvent + done chan struct{} + doneOnce sync.Once + eventsOnce sync.Once } // sseEvent is an internal representation of a single SSE message. @@ -34,15 +46,48 @@ type sseEvent struct { } // NewSSEBroker creates a ready-to-use SSE broker. +// +// Example: +// +// broker := api.NewSSEBroker() func NewSSEBroker() *SSEBroker { return &SSEBroker{ clients: make(map[*sseClient]struct{}), } } +// normaliseSSEPath coerces custom SSE paths into a stable form. +// The path always begins with a single slash and never ends with one. +func normaliseSSEPath(path string) string { + path = strings.TrimSpace(path) + if path == "" { + return defaultSSEPath + } + + path = "/" + strings.Trim(path, "/") + if path == "/" { + return defaultSSEPath + } + + return path +} + +// resolveSSEPath returns the configured SSE path or the default path when +// no override has been provided. +func resolveSSEPath(path string) string { + if strings.TrimSpace(path) == "" { + return defaultSSEPath + } + return normaliseSSEPath(path) +} + // Publish sends an event to all clients subscribed to the given channel. // Clients subscribed to an empty channel (no ?channel= param) receive // events on every channel. The data value is JSON-encoded before sending. +// +// Example: +// +// broker.Publish("system", "ready", map[string]any{"status": "ok"}) func (b *SSEBroker) Publish(channel, event string, data any) { encoded, err := json.Marshal(data) if err != nil { @@ -60,6 +105,11 @@ func (b *SSEBroker) Publish(channel, event string, data any) { for client := range b.clients { // Send to clients on the matching channel, or clients with no channel filter. if client.channel == "" || client.channel == channel { + select { + case <-client.done: + continue + default: + } select { case client.events <- msg: case <-client.done: @@ -73,6 +123,10 @@ func (b *SSEBroker) Publish(channel, event string, data any) { // Handler returns a Gin handler for the SSE endpoint. Clients connect with // a GET request and receive events as text/event-stream. An optional // ?channel= query parameter subscribes the client to a specific channel. +// +// Example: +// +// engine.GET("/events", broker.Handler()) func (b *SSEBroker) Handler() gin.HandlerFunc { return func(c *gin.Context) { channel := c.Query("channel") @@ -85,13 +139,15 @@ func (b *SSEBroker) Handler() gin.HandlerFunc { b.mu.Lock() b.clients[client] = struct{}{} + b.wg.Add(1) b.mu.Unlock() defer func() { - close(client.done) b.mu.Lock() + client.signalDone() delete(b.clients, client) b.mu.Unlock() + b.wg.Done() }() // Set SSE headers. @@ -108,7 +164,20 @@ func (b *SSEBroker) Handler() gin.HandlerFunc { select { case <-ctx.Done(): return - case evt := <-client.events: + case <-client.done: + return + default: + } + + select { + case <-ctx.Done(): + return + case <-client.done: + return + case evt, ok := <-client.events: + if !ok { + return + } _, err := fmt.Fprintf(c.Writer, "event: %s\ndata: %s\n\n", evt.Event, evt.Data) if err != nil { return @@ -123,23 +192,43 @@ func (b *SSEBroker) Handler() gin.HandlerFunc { } // ClientCount returns the number of currently connected SSE clients. +// +// Example: +// +// n := broker.ClientCount() func (b *SSEBroker) ClientCount() int { b.mu.RLock() defer b.mu.RUnlock() return len(b.clients) } -// Drain closes all connected clients by writing an empty response. -// Useful for graceful shutdown. +// Drain signals all connected clients to disconnect and waits for their +// handler goroutines to exit. Useful for graceful shutdown. +// +// Example: +// +// broker.Drain() func (b *SSEBroker) Drain() { b.mu.Lock() - defer b.mu.Unlock() for client := range b.clients { - select { - case <-client.done: - default: - // Write EOF to trigger client disconnect via their event loop. - close(client.events) - } + client.signalDone() + client.closeEvents() } + b.mu.Unlock() + + b.wg.Wait() +} + +// signalDone closes the client done channel once. +func (c *sseClient) signalDone() { + c.doneOnce.Do(func() { + close(c.done) + }) +} + +// closeEvents closes the client event channel once. +func (c *sseClient) closeEvents() { + c.eventsOnce.Do(func() { + close(c.events) + }) } diff --git a/sse_test.go b/sse_test.go index 7467b38..cfa950c 100644 --- a/sse_test.go +++ b/sse_test.go @@ -4,6 +4,8 @@ package api_test import ( "bufio" + "context" + "net" "net/http" "net/http/httptest" "strings" @@ -46,6 +48,44 @@ func TestWithSSE_Good_EndpointExists(t *testing.T) { } } +func TestWithSSE_Good_CustomPath(t *testing.T) { + gin.SetMode(gin.TestMode) + + broker := api.NewSSEBroker() + e, err := api.New(api.WithSSE(broker), api.WithSSEPath("/stream")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + srv := httptest.NewServer(e.Handler()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/stream") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + + ct := resp.Header.Get("Content-Type") + if !strings.HasPrefix(ct, "text/event-stream") { + t.Fatalf("expected Content-Type starting with text/event-stream, got %q", ct) + } + + notFoundResp, err := http.Get(srv.URL + "/events") + if err != nil { + t.Fatalf("request to default SSE path failed: %v", err) + } + defer notFoundResp.Body.Close() + + if notFoundResp.StatusCode != http.StatusNotFound { + t.Fatalf("expected 404 at default /events when custom path is configured, got %d", notFoundResp.StatusCode) + } +} + func TestWithSSE_Good_ReceivesPublishedEvent(t *testing.T) { gin.SetMode(gin.TestMode) @@ -202,6 +242,67 @@ func TestWithSSE_Good_CombinesWithOtherMiddleware(t *testing.T) { } } +func TestWithSSE_Good_WithResponseMetaStillStreamsEvents(t *testing.T) { + gin.SetMode(gin.TestMode) + + broker := api.NewSSEBroker() + e, err := api.New( + api.WithRequestID(), + api.WithResponseMeta(), + api.WithSSE(broker), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + srv := httptest.NewServer(e.Handler()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/events") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if ct := resp.Header.Get("Content-Type"); !strings.HasPrefix(ct, "text/event-stream") { + t.Fatalf("expected Content-Type starting with text/event-stream, got %q", ct) + } + if reqID := resp.Header.Get("X-Request-ID"); reqID == "" { + t.Fatal("expected X-Request-ID header from RequestID middleware") + } + + waitForClients(t, broker, 1) + + broker.Publish("test", "greeting", map[string]string{"msg": "hello"}) + + scanner := bufio.NewScanner(resp.Body) + var eventLine string + + deadline := time.After(3 * time.Second) + done := make(chan struct{}) + + go func() { + defer close(done) + for scanner.Scan() { + line := scanner.Text() + if after, ok := strings.CutPrefix(line, "event: "); ok { + eventLine = after + return + } + } + }() + + select { + case <-done: + case <-deadline: + t.Fatal("timed out waiting for SSE event with response meta enabled") + } + + if eventLine != "greeting" { + t.Fatalf("expected event=%q, got %q", "greeting", eventLine) + } +} + func TestWithSSE_Good_MultipleClients(t *testing.T) { gin.SetMode(gin.TestMode) @@ -269,6 +370,58 @@ func TestWithSSE_Good_MultipleClients(t *testing.T) { wg.Wait() } +func TestWithSSE_Good_DrainDisconnectsClients(t *testing.T) { + gin.SetMode(gin.TestMode) + + broker := api.NewSSEBroker() + e, err := api.New(api.WithSSE(broker)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + srv := httptest.NewServer(e.Handler()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/events") + if err != nil { + t.Fatalf("request failed: %v", err) + } + + waitForClients(t, broker, 1) + + streamDone := make(chan struct{}) + go func() { + defer close(streamDone) + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + } + }() + + drainDone := make(chan struct{}) + go func() { + broker.Drain() + close(drainDone) + }() + + select { + case <-drainDone: + case <-time.After(3 * time.Second): + t.Fatal("timed out waiting for SSE drain to complete") + } + + select { + case <-streamDone: + case <-time.After(3 * time.Second): + t.Fatal("timed out waiting for SSE client to disconnect") + } + + if got := broker.ClientCount(); got != 0 { + t.Fatalf("expected 0 connected SSE clients after drain, got %d", got) + } + + _ = resp.Body.Close() +} + // ── No SSE broker ──────────────────────────────────────────────────────── func TestNoSSEBroker_Good(t *testing.T) { @@ -287,6 +440,63 @@ func TestNoSSEBroker_Good(t *testing.T) { } } +func TestWithSSE_Good_EngineShutdownDrainsClients(t *testing.T) { + gin.SetMode(gin.TestMode) + + broker := api.NewSSEBroker() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to allocate listener: %v", err) + } + addr := ln.Addr().String() + _ = ln.Close() + + e, err := api.New(api.WithAddr(addr), api.WithSSE(broker)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + errCh := make(chan error, 1) + go func() { + errCh <- e.Serve(ctx) + }() + + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + conn, err := net.DialTimeout("tcp", addr, 100*time.Millisecond) + if err == nil { + _ = conn.Close() + break + } + time.Sleep(50 * time.Millisecond) + } + + resp, err := http.Get("http://" + addr + "/events") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + waitForClients(t, broker, 1) + + cancel() + + select { + case serveErr := <-errCh: + if serveErr != nil { + t.Fatalf("Serve returned unexpected error: %v", serveErr) + } + case <-time.After(5 * time.Second): + t.Fatal("Serve did not return within 5 seconds after context cancellation") + } + + if got := broker.ClientCount(); got != 0 { + t.Fatalf("expected SSE broker to drain all clients, got %d", got) + } +} + // ── Helpers ────────────────────────────────────────────────────────────── // waitForClients polls the broker until the expected number of clients diff --git a/sunset.go b/sunset.go new file mode 100644 index 0000000..24d2709 --- /dev/null +++ b/sunset.go @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api + +import ( + "net/http" + "strings" + "time" + + "github.com/gin-gonic/gin" +) + +// ApiSunset returns middleware that marks a route or group as deprecated. +// +// The middleware appends standard deprecation headers to every response: +// Deprecation, optional Sunset, optional Link, and X-API-Warn. Existing header +// values are preserved so downstream middleware and handlers can keep their own +// link relations or warning metadata. +// +// Example: +// +// rg.Use(api.ApiSunset("2025-06-01", "/api/v2/users")) +func ApiSunset(sunsetDate, replacement string) gin.HandlerFunc { + sunsetDate = strings.TrimSpace(sunsetDate) + replacement = strings.TrimSpace(replacement) + formatted := formatSunsetDate(sunsetDate) + warning := "This endpoint is deprecated." + if sunsetDate != "" { + warning = "This endpoint is deprecated and will be removed on " + sunsetDate + "." + } + + return func(c *gin.Context) { + c.Next() + + c.Writer.Header().Add("Deprecation", "true") + if formatted != "" { + c.Writer.Header().Add("Sunset", formatted) + } + if replacement != "" { + c.Writer.Header().Add("Link", "<"+replacement+">; rel=\"successor-version\"") + } + c.Writer.Header().Add("X-API-Warn", warning) + } +} + +func formatSunsetDate(sunsetDate string) string { + sunsetDate = strings.TrimSpace(sunsetDate) + if sunsetDate == "" { + return "" + } + if strings.Contains(sunsetDate, ",") { + return sunsetDate + } + + parsed, err := time.Parse("2006-01-02", sunsetDate) + if err != nil { + return sunsetDate + } + + return parsed.UTC().Format(http.TimeFormat) +} diff --git a/sunset_test.go b/sunset_test.go new file mode 100644 index 0000000..1348be7 --- /dev/null +++ b/sunset_test.go @@ -0,0 +1,138 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + + api "dappco.re/go/core/api" +) + +type sunsetStubGroup struct{} + +func (sunsetStubGroup) Name() string { return "legacy" } +func (sunsetStubGroup) BasePath() string { return "/legacy" } +func (sunsetStubGroup) RegisterRoutes(rg *gin.RouterGroup) { + rg.GET("/status", func(c *gin.Context) { + c.JSON(http.StatusOK, api.OK("ok")) + }) +} + +type sunsetLinkStubGroup struct{} + +func (sunsetLinkStubGroup) Name() string { return "legacy-link" } +func (sunsetLinkStubGroup) BasePath() string { return "/legacy-link" } +func (sunsetLinkStubGroup) RegisterRoutes(rg *gin.RouterGroup) { + rg.GET("/status", func(c *gin.Context) { + c.Header("Link", "; rel=\"help\"") + c.JSON(http.StatusOK, api.OK("ok")) + }) +} + +type sunsetHeaderStubGroup struct{} + +func (sunsetHeaderStubGroup) Name() string { return "legacy-headers" } +func (sunsetHeaderStubGroup) BasePath() string { return "/legacy-headers" } +func (sunsetHeaderStubGroup) RegisterRoutes(rg *gin.RouterGroup) { + rg.GET("/status", func(c *gin.Context) { + c.Header("Deprecation", "false") + c.Header("Sunset", "Wed, 01 Jan 2025 00:00:00 GMT") + c.Header("X-API-Warn", "Existing warning") + c.Header("Link", "; rel=\"help\"") + c.JSON(http.StatusOK, api.OK("ok")) + }) +} + +func TestWithSunset_Good_AddsDeprecationHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + + e, err := api.New(api.WithSunset("2025-06-01", "/api/v2/status")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + e.Register(sunsetStubGroup{}) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/legacy/status", nil) + e.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + if got := w.Header().Get("Deprecation"); got != "true" { + t.Fatalf("expected Deprecation=true, got %q", got) + } + if got := w.Header().Get("Sunset"); got != "Sun, 01 Jun 2025 00:00:00 GMT" { + t.Fatalf("expected formatted Sunset header, got %q", got) + } + if got := w.Header().Get("Link"); got != "; rel=\"successor-version\"" { + t.Fatalf("expected successor Link header, got %q", got) + } + if got := w.Header().Get("X-API-Warn"); got != "This endpoint is deprecated and will be removed on 2025-06-01." { + t.Fatalf("expected deprecation warning, got %q", got) + } +} + +func TestWithSunset_Good_PreservesExistingLinkHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + + e, err := api.New(api.WithSunset("2025-06-01", "/api/v2/status")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + e.Register(sunsetLinkStubGroup{}) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/legacy-link/status", nil) + e.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + + links := w.Header().Values("Link") + if len(links) != 2 { + t.Fatalf("expected 2 Link header values, got %v", links) + } + if links[0] != "; rel=\"help\"" { + t.Fatalf("expected existing Link header to be preserved first, got %q", links[0]) + } + if links[1] != "; rel=\"successor-version\"" { + t.Fatalf("expected successor Link header to be appended, got %q", links[1]) + } +} + +func TestWithSunset_Good_PreservesExistingDeprecationHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + + e, err := api.New(api.WithSunset("2025-06-01", "/api/v2/status")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + e.Register(sunsetHeaderStubGroup{}) + + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/legacy-headers/status", nil) + e.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", w.Code) + } + + if got := w.Header().Values("Deprecation"); len(got) != 2 { + t.Fatalf("expected 2 Deprecation header values, got %v", got) + } + if got := w.Header().Values("Sunset"); len(got) != 2 { + t.Fatalf("expected 2 Sunset header values, got %v", got) + } + if got := w.Header().Values("X-API-Warn"); len(got) != 2 { + t.Fatalf("expected 2 X-API-Warn header values, got %v", got) + } + if got := w.Header().Values("Link"); len(got) != 2 { + t.Fatalf("expected 2 Link header values, got %v", got) + } +} diff --git a/swagger.go b/swagger.go index 65b45c5..36ec01b 100644 --- a/swagger.go +++ b/swagger.go @@ -4,6 +4,8 @@ package api import ( "fmt" + "net/http" + "strings" "sync" "sync/atomic" @@ -11,12 +13,16 @@ import ( swaggerFiles "github.com/swaggo/files" ginSwagger "github.com/swaggo/gin-swagger" "github.com/swaggo/swag" + "slices" ) // swaggerSeq provides unique instance names so multiple Engine instances // (common in tests) do not collide in the global swag registry. var swaggerSeq atomic.Uint64 +// defaultSwaggerPath is the URL path where the Swagger UI is mounted. +const defaultSwaggerPath = "/swagger" + // swaggerSpec wraps SpecBuilder to satisfy the swag.Spec interface. // The spec is built once on first access and cached. type swaggerSpec struct { @@ -26,6 +32,15 @@ type swaggerSpec struct { doc string } +var _ swag.Swagger = (*swaggerSpec)(nil) + +func newSwaggerSpec(builder *SpecBuilder, groups []RouteGroup) *swaggerSpec { + return &swaggerSpec{ + builder: builder, + groups: slices.Clone(groups), + } +} + // ReadDoc returns the OpenAPI 3.1 JSON document for this spec. func (s *swaggerSpec) ReadDoc() string { s.once.Do(func() { @@ -40,16 +55,38 @@ func (s *swaggerSpec) ReadDoc() string { } // registerSwagger mounts the Swagger UI and doc.json endpoint. -func registerSwagger(g *gin.Engine, title, description, version string, groups []RouteGroup) { - spec := &swaggerSpec{ - builder: &SpecBuilder{ - Title: title, - Description: description, - Version: version, - }, - groups: groups, - } +func registerSwagger(g *gin.Engine, e *Engine, groups []RouteGroup) { + swaggerPath := resolveSwaggerPath(e.swaggerPath) + spec := newSwaggerSpec(e.OpenAPISpecBuilder(), groups) name := fmt.Sprintf("swagger_%d", swaggerSeq.Add(1)) swag.Register(name, spec) - g.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.NewHandler(), ginSwagger.InstanceName(name))) + g.GET(swaggerPath, func(c *gin.Context) { + c.Redirect(http.StatusMovedPermanently, swaggerPath+"/") + }) + g.GET(swaggerPath+"/*any", ginSwagger.WrapHandler(swaggerFiles.NewHandler(), ginSwagger.InstanceName(name))) +} + +// normaliseSwaggerPath coerces custom Swagger paths into a stable form. +// The path always begins with a single slash and never ends with one. +func normaliseSwaggerPath(path string) string { + path = strings.TrimSpace(path) + if path == "" { + return defaultSwaggerPath + } + + path = "/" + strings.Trim(path, "/") + if path == "/" { + return defaultSwaggerPath + } + + return path +} + +// resolveSwaggerPath returns the configured Swagger path or the default path +// when no override has been provided. +func resolveSwaggerPath(path string) string { + if strings.TrimSpace(path) == "" { + return defaultSwaggerPath + } + return normaliseSwaggerPath(path) } diff --git a/swagger_internal_test.go b/swagger_internal_test.go new file mode 100644 index 0000000..70260b7 --- /dev/null +++ b/swagger_internal_test.go @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api + +import ( + "encoding/json" + "testing" + + "github.com/gin-gonic/gin" +) + +type swaggerSnapshotGroup struct { + name string + basePath string + descs []RouteDescription +} + +func (g *swaggerSnapshotGroup) Name() string { return g.name } +func (g *swaggerSnapshotGroup) BasePath() string { return g.basePath } +func (g *swaggerSnapshotGroup) RegisterRoutes(_ *gin.RouterGroup) {} +func (g *swaggerSnapshotGroup) Describe() []RouteDescription { + return g.descs +} + +func TestSwaggerSpec_ReadDoc_Good_SnapshotsGroups(t *testing.T) { + original := &swaggerSnapshotGroup{ + name: "first", + basePath: "/first", + descs: []RouteDescription{ + { + Method: "GET", + Path: "/ping", + Summary: "First group", + Response: map[string]any{ + "type": "string", + }, + }, + }, + } + replacement := &swaggerSnapshotGroup{ + name: "second", + basePath: "/second", + descs: []RouteDescription{ + { + Method: "GET", + Path: "/pong", + Summary: "Second group", + Response: map[string]any{ + "type": "string", + }, + }, + }, + } + + groups := []RouteGroup{original} + spec := newSwaggerSpec(&SpecBuilder{ + Title: "Test", + Version: "1.0.0", + }, groups) + + groups[0] = replacement + + var doc map[string]any + if err := json.Unmarshal([]byte(spec.ReadDoc()), &doc); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + paths := doc["paths"].(map[string]any) + if _, ok := paths["/first/ping"]; !ok { + t.Fatal("expected original group path to remain in the spec") + } + if _, ok := paths["/second/pong"]; ok { + t.Fatal("did not expect mutated group path to leak into the spec") + } +} diff --git a/swagger_test.go b/swagger_test.go index 636f89f..77c820b 100644 --- a/swagger_test.go +++ b/swagger_test.go @@ -65,6 +65,117 @@ func TestSwaggerEndpoint_Good(t *testing.T) { } } +func TestSwaggerEndpoint_Good_CustomPath(t *testing.T) { + gin.SetMode(gin.TestMode) + + e, err := api.New( + api.WithSwagger("Test API", "A test API service", "1.0.0"), + api.WithSwaggerPath("/docs"), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + srv := httptest.NewServer(e.Handler()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/docs/doc.json") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read body: %v", err) + } + if len(body) == 0 { + t.Fatal("expected non-empty response body") + } + + var doc map[string]any + if err := json.Unmarshal(body, &doc); err != nil { + t.Fatalf("expected valid JSON, got unmarshal error: %v", err) + } + + info, ok := doc["info"].(map[string]any) + if !ok { + t.Fatal("expected 'info' object in swagger doc") + } + if info["title"] != "Test API" { + t.Fatalf("expected title=%q, got %q", "Test API", info["title"]) + } +} + +func TestSwaggerEndpoint_Good_BasePathRedirect(t *testing.T) { + gin.SetMode(gin.TestMode) + + e, err := api.New(api.WithSwagger("Test API", "A test API service", "1.0.0")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + srv := httptest.NewServer(e.Handler()) + defer srv.Close() + + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + resp, err := client.Get(srv.URL + "/swagger") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusMovedPermanently { + t.Fatalf("expected 301 redirect, got %d", resp.StatusCode) + } + if got := resp.Header.Get("Location"); got != "/swagger/" { + t.Fatalf("expected Location=/swagger/, got %q", got) + } +} + +func TestSwaggerEndpoint_Good_CustomBasePathRedirect(t *testing.T) { + gin.SetMode(gin.TestMode) + + e, err := api.New( + api.WithSwagger("Test API", "A test API service", "1.0.0"), + api.WithSwaggerPath("/docs"), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + srv := httptest.NewServer(e.Handler()) + defer srv.Close() + + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + resp, err := client.Get(srv.URL + "/docs") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusMovedPermanently { + t.Fatalf("expected 301 redirect, got %d", resp.StatusCode) + } + if got := resp.Header.Get("Location"); got != "/docs/" { + t.Fatalf("expected Location=/docs/, got %q", got) + } +} + func TestSwaggerDisabledByDefault_Good(t *testing.T) { gin.SetMode(gin.TestMode) @@ -81,6 +192,32 @@ func TestSwaggerDisabledByDefault_Good(t *testing.T) { } } +func TestSwaggerAuth_Good_CustomPathBypassesBearerAuth(t *testing.T) { + gin.SetMode(gin.TestMode) + + e, err := api.New( + api.WithBearerAuth("secret"), + api.WithSwagger("Test API", "A test API service", "1.0.0"), + api.WithSwaggerPath("/docs"), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + srv := httptest.NewServer(e.Handler()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/docs/doc.json") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200 for custom swagger path without auth, got %d", resp.StatusCode) + } +} + func TestSwagger_Good_SpecNotEmpty(t *testing.T) { gin.SetMode(gin.TestMode) @@ -200,6 +337,87 @@ func TestSwagger_Good_WithToolBridge(t *testing.T) { } } +func TestSwagger_Good_IncludesSSEEndpoint(t *testing.T) { + gin.SetMode(gin.TestMode) + + broker := api.NewSSEBroker() + e, err := api.New(api.WithSwagger("SSE API", "SSE test", "1.0.0"), api.WithSSE(broker)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + srv := httptest.NewServer(e.Handler()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/swagger/doc.json") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read body: %v", err) + } + + var doc map[string]any + if err := json.Unmarshal(body, &doc); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + paths := doc["paths"].(map[string]any) + pathItem, ok := paths["/events"].(map[string]any) + if !ok { + t.Fatal("expected /events path in swagger doc") + } + + getOp := pathItem["get"].(map[string]any) + if getOp["operationId"] != "get_events" { + t.Fatalf("expected SSE operationId to be get_events, got %v", getOp["operationId"]) + } +} + +func TestSwagger_Good_UsesCustomSSEPath(t *testing.T) { + gin.SetMode(gin.TestMode) + + broker := api.NewSSEBroker() + e, err := api.New( + api.WithSwagger("SSE API", "SSE test", "1.0.0"), + api.WithSSE(broker), + api.WithSSEPath("/stream"), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + srv := httptest.NewServer(e.Handler()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/swagger/doc.json") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read body: %v", err) + } + + var doc map[string]any + if err := json.Unmarshal(body, &doc); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + paths := doc["paths"].(map[string]any) + if _, ok := paths["/stream"]; !ok { + t.Fatal("expected custom SSE path /stream in swagger doc") + } + if _, ok := paths["/events"]; ok { + t.Fatal("did not expect default /events path when custom SSE path is configured") + } +} + func TestSwagger_Good_CachesSpec(t *testing.T) { spec := &swaggerSpecHelper{ title: "Cache Test", @@ -258,6 +476,389 @@ func TestSwagger_Good_InfoFromOptions(t *testing.T) { } } +func TestSwagger_Good_IncludesGraphQLEndpoint(t *testing.T) { + gin.SetMode(gin.TestMode) + + e, err := api.New(api.WithGraphQL(newTestSchema()), api.WithSwagger("Graph API", "GraphQL docs", "1.0.0")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + srv := httptest.NewServer(e.Handler()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/swagger/doc.json") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read body: %v", err) + } + + var doc map[string]any + if err := json.Unmarshal(body, &doc); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + paths, ok := doc["paths"].(map[string]any) + if !ok { + t.Fatal("expected paths object in swagger doc") + } + if _, ok := paths["/graphql"]; !ok { + t.Fatal("expected /graphql path in swagger doc") + } +} + +func TestSwagger_Good_UsesLicenseMetadata(t *testing.T) { + gin.SetMode(gin.TestMode) + + e, err := api.New( + api.WithSwagger("Licensed API", "Licensed test", "1.0.0"), + api.WithSwaggerLicense("EUPL-1.2", "https://eupl.eu/1.2/en/"), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + srv := httptest.NewServer(e.Handler()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/swagger/doc.json") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read body: %v", err) + } + + var doc map[string]any + if err := json.Unmarshal(body, &doc); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + info := doc["info"].(map[string]any) + license, ok := info["license"].(map[string]any) + if !ok { + t.Fatal("expected license metadata in swagger doc") + } + if license["name"] != "EUPL-1.2" { + t.Fatalf("expected license name=%q, got %v", "EUPL-1.2", license["name"]) + } + if license["url"] != "https://eupl.eu/1.2/en/" { + t.Fatalf("expected license url=%q, got %v", "https://eupl.eu/1.2/en/", license["url"]) + } +} + +func TestSwagger_Good_UsesContactMetadata(t *testing.T) { + gin.SetMode(gin.TestMode) + + e, err := api.New( + api.WithSwagger("Contact API", "Contact test", "1.0.0"), + api.WithSwaggerContact("API Support", "https://example.com/support", "support@example.com"), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + srv := httptest.NewServer(e.Handler()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/swagger/doc.json") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read body: %v", err) + } + + var doc map[string]any + if err := json.Unmarshal(body, &doc); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + info := doc["info"].(map[string]any) + contact, ok := info["contact"].(map[string]any) + if !ok { + t.Fatal("expected contact metadata in swagger doc") + } + if contact["name"] != "API Support" { + t.Fatalf("expected contact name=%q, got %v", "API Support", contact["name"]) + } + if contact["url"] != "https://example.com/support" { + t.Fatalf("expected contact url=%q, got %v", "https://example.com/support", contact["url"]) + } + if contact["email"] != "support@example.com" { + t.Fatalf("expected contact email=%q, got %v", "support@example.com", contact["email"]) + } +} + +func TestSwagger_Good_UsesTermsOfServiceMetadata(t *testing.T) { + gin.SetMode(gin.TestMode) + + e, err := api.New( + api.WithSwagger("Terms API", "Terms test", "1.0.0"), + api.WithSwaggerTermsOfService("https://example.com/terms"), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + srv := httptest.NewServer(e.Handler()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/swagger/doc.json") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read body: %v", err) + } + + var doc map[string]any + if err := json.Unmarshal(body, &doc); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + info := doc["info"].(map[string]any) + if info["termsOfService"] != "https://example.com/terms" { + t.Fatalf("expected termsOfService=%q, got %v", "https://example.com/terms", info["termsOfService"]) + } +} + +func TestSwagger_Good_UsesExternalDocsMetadata(t *testing.T) { + gin.SetMode(gin.TestMode) + + e, err := api.New( + api.WithSwagger("Docs API", "Docs test", "1.0.0"), + api.WithSwaggerExternalDocs("Developer guide", "https://example.com/docs"), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + srv := httptest.NewServer(e.Handler()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/swagger/doc.json") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read body: %v", err) + } + + var doc map[string]any + if err := json.Unmarshal(body, &doc); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + externalDocs, ok := doc["externalDocs"].(map[string]any) + if !ok { + t.Fatal("expected externalDocs metadata in swagger doc") + } + if externalDocs["description"] != "Developer guide" { + t.Fatalf("expected externalDocs description=%q, got %v", "Developer guide", externalDocs["description"]) + } + if externalDocs["url"] != "https://example.com/docs" { + t.Fatalf("expected externalDocs url=%q, got %v", "https://example.com/docs", externalDocs["url"]) + } +} + +func TestSwagger_Good_IgnoresBlankMetadataOverrides(t *testing.T) { + gin.SetMode(gin.TestMode) + + e, err := api.New( + api.WithSwagger("Stable API", "Blank override test", "1.0.0"), + api.WithSwaggerTermsOfService("https://example.com/terms"), + api.WithSwaggerTermsOfService(""), + api.WithSwaggerContact("API Support", "https://example.com/support", "support@example.com"), + api.WithSwaggerContact("", "", ""), + api.WithSwaggerLicense("EUPL-1.2", "https://eupl.eu/1.2/en/"), + api.WithSwaggerLicense("", ""), + api.WithSwaggerExternalDocs("Developer guide", "https://example.com/docs"), + api.WithSwaggerExternalDocs("", ""), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + srv := httptest.NewServer(e.Handler()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/swagger/doc.json") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read body: %v", err) + } + + var doc map[string]any + if err := json.Unmarshal(body, &doc); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + info := doc["info"].(map[string]any) + if info["termsOfService"] != "https://example.com/terms" { + t.Fatalf("expected termsOfService to survive blank override, got %v", info["termsOfService"]) + } + + contact, ok := info["contact"].(map[string]any) + if !ok { + t.Fatal("expected contact metadata in swagger doc") + } + if contact["name"] != "API Support" { + t.Fatalf("expected contact name to survive blank override, got %v", contact["name"]) + } + if contact["url"] != "https://example.com/support" { + t.Fatalf("expected contact url to survive blank override, got %v", contact["url"]) + } + if contact["email"] != "support@example.com" { + t.Fatalf("expected contact email to survive blank override, got %v", contact["email"]) + } + + license, ok := info["license"].(map[string]any) + if !ok { + t.Fatal("expected license metadata in swagger doc") + } + if license["name"] != "EUPL-1.2" { + t.Fatalf("expected license name to survive blank override, got %v", license["name"]) + } + if license["url"] != "https://eupl.eu/1.2/en/" { + t.Fatalf("expected license url to survive blank override, got %v", license["url"]) + } + + externalDocs, ok := doc["externalDocs"].(map[string]any) + if !ok { + t.Fatal("expected externalDocs metadata in swagger doc") + } + if externalDocs["description"] != "Developer guide" { + t.Fatalf("expected externalDocs description to survive blank override, got %v", externalDocs["description"]) + } + if externalDocs["url"] != "https://example.com/docs" { + t.Fatalf("expected externalDocs url to survive blank override, got %v", externalDocs["url"]) + } +} + +func TestSwagger_Good_UsesServerMetadata(t *testing.T) { + gin.SetMode(gin.TestMode) + + e, err := api.New( + api.WithSwagger("Server API", "Server metadata test", "1.0.0"), + api.WithSwaggerServers(" https://api.example.com ", "/", "", "https://api.example.com"), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + srv := httptest.NewServer(e.Handler()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/swagger/doc.json") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read body: %v", err) + } + + var doc map[string]any + if err := json.Unmarshal(body, &doc); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + servers, ok := doc["servers"].([]any) + if !ok { + t.Fatalf("expected servers array, got %T", doc["servers"]) + } + if len(servers) != 2 { + t.Fatalf("expected 2 normalised servers, got %d", len(servers)) + } + + first := servers[0].(map[string]any) + if first["url"] != "https://api.example.com" { + t.Fatalf("expected first server url=%q, got %v", "https://api.example.com", first["url"]) + } + + second := servers[1].(map[string]any) + if second["url"] != "/" { + t.Fatalf("expected second server url=%q, got %v", "/", second["url"]) + } +} + +func TestSwagger_Good_AppendsServerMetadataAcrossCalls(t *testing.T) { + gin.SetMode(gin.TestMode) + + e, err := api.New( + api.WithSwagger("Server API", "Server metadata test", "1.0.0"), + api.WithSwaggerServers("https://api.example.com", "/"), + api.WithSwaggerServers(" https://docs.example.com ", "/", "https://api.example.com"), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + srv := httptest.NewServer(e.Handler()) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/swagger/doc.json") + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read body: %v", err) + } + + var doc map[string]any + if err := json.Unmarshal(body, &doc); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + + servers, ok := doc["servers"].([]any) + if !ok { + t.Fatalf("expected servers array, got %T", doc["servers"]) + } + if len(servers) != 3 { + t.Fatalf("expected 3 normalised servers, got %d", len(servers)) + } + + expected := []string{"https://api.example.com", "/", "https://docs.example.com"} + for i, want := range expected { + got := servers[i].(map[string]any)["url"] + if got != want { + t.Fatalf("expected server[%d] url=%q, got %v", i, want, got) + } + } +} + func TestSwagger_Good_ValidOpenAPI(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/timeout_test.go b/timeout_test.go index c0e99a8..7630712 100644 --- a/timeout_test.go +++ b/timeout_test.go @@ -148,15 +148,8 @@ func TestWithTimeout_Good_CombinesWithOtherMiddleware(t *testing.T) { } func TestWithTimeout_Ugly_ZeroDurationDoesNotPanic(t *testing.T) { - skipIfRaceDetector(t) gin.SetMode(gin.TestMode) - defer func() { - if r := recover(); r != nil { - t.Fatalf("WithTimeout(0) panicked: %v", r) - } - }() - e, err := api.New(api.WithTimeout(0)) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -168,5 +161,15 @@ func TestWithTimeout_Ugly_ZeroDurationDoesNotPanic(t *testing.T) { req, _ := http.NewRequest(http.MethodGet, "/stub/ping", nil) h.ServeHTTP(w, req) - // We only care that it did not panic. Status may vary with zero timeout. + if w.Code != http.StatusOK { + t.Fatalf("expected 200 with zero timeout disabled, got %d", w.Code) + } + + var resp api.Response[string] + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal error: %v", err) + } + if resp.Data != "pong" { + t.Fatalf("expected Data=%q, got %q", "pong", resp.Data) + } } diff --git a/tracing.go b/tracing.go index 5fecb2f..01c2d95 100644 --- a/tracing.go +++ b/tracing.go @@ -24,6 +24,10 @@ import ( // otel.SetTextMapPropagator(propagation.TraceContext{}) // // engine, _ := api.New(api.WithTracing("my-service")) +// +// Example: +// +// api.New(api.WithTracing("my-service")) func WithTracing(serviceName string, opts ...otelgin.Option) Option { return func(e *Engine) { e.middlewares = append(e.middlewares, otelgin.Middleware(serviceName, opts...)) @@ -37,6 +41,11 @@ func WithTracing(serviceName string, opts ...otelgin.Option) Option { // This is a convenience helper for tests and simple deployments. // Production setups should build their own TracerProvider with batching, // resource attributes, and appropriate exporters. +// +// Example: +// +// tp := api.NewTracerProvider(exporter) +// _ = tp.Shutdown(context.Background()) func NewTracerProvider(exporter sdktrace.SpanExporter) *sdktrace.TracerProvider { tp := sdktrace.NewTracerProvider( sdktrace.WithSyncer(exporter), diff --git a/transport.go b/transport.go new file mode 100644 index 0000000..a090729 --- /dev/null +++ b/transport.go @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package api + +import "strings" + +// TransportConfig captures the configured transport endpoints and flags for an Engine. +// +// It is intentionally small and serialisable so callers can inspect the active HTTP +// surface without rebuilding an OpenAPI document. +// +// Example: +// +// cfg := api.TransportConfig{SwaggerPath: "/swagger", WSPath: "/ws"} +type TransportConfig struct { + SwaggerEnabled bool + SwaggerPath string + GraphQLPath string + GraphQLEnabled bool + GraphQLPlayground bool + GraphQLPlaygroundPath string + WSEnabled bool + WSPath string + SSEEnabled bool + SSEPath string + PprofEnabled bool + ExpvarEnabled bool +} + +// TransportConfig returns the currently configured transport metadata for the engine. +// +// The result snapshots the Engine state at call time and normalises any configured +// URL paths using the same rules as the runtime handlers. +// +// Example: +// +// cfg := engine.TransportConfig() +func (e *Engine) TransportConfig() TransportConfig { + if e == nil { + return TransportConfig{} + } + + cfg := TransportConfig{ + SwaggerEnabled: e.swaggerEnabled, + WSEnabled: e.wsHandler != nil, + SSEEnabled: e.sseBroker != nil, + PprofEnabled: e.pprofEnabled, + ExpvarEnabled: e.expvarEnabled, + } + gql := e.GraphQLConfig() + cfg.GraphQLEnabled = gql.Enabled + cfg.GraphQLPlayground = gql.Playground + cfg.GraphQLPlaygroundPath = gql.PlaygroundPath + + if e.swaggerEnabled || strings.TrimSpace(e.swaggerPath) != "" { + cfg.SwaggerPath = resolveSwaggerPath(e.swaggerPath) + } + if gql.Path != "" { + cfg.GraphQLPath = gql.Path + } + if e.wsHandler != nil || strings.TrimSpace(e.wsPath) != "" { + cfg.WSPath = resolveWSPath(e.wsPath) + } + if e.sseBroker != nil || strings.TrimSpace(e.ssePath) != "" { + cfg.SSEPath = resolveSSEPath(e.ssePath) + } + + return cfg +} diff --git a/websocket.go b/websocket.go index 8eb7a33..fc5bedc 100644 --- a/websocket.go +++ b/websocket.go @@ -4,14 +4,43 @@ package api import ( "net/http" + "strings" "github.com/gin-gonic/gin" ) -// wrapWSHandler adapts a standard http.Handler to a Gin handler for the /ws route. +// defaultWSPath is the URL path where the WebSocket endpoint is mounted. +const defaultWSPath = "/ws" + +// wrapWSHandler adapts a standard http.Handler to a Gin handler for the WebSocket route. // The underlying handler is responsible for upgrading the connection to WebSocket. func wrapWSHandler(h http.Handler) gin.HandlerFunc { return func(c *gin.Context) { h.ServeHTTP(c.Writer, c.Request) } } + +// normaliseWSPath coerces custom WebSocket paths into a stable form. +// The path always begins with a single slash and never ends with one. +func normaliseWSPath(path string) string { + path = strings.TrimSpace(path) + if path == "" { + return defaultWSPath + } + + path = "/" + strings.Trim(path, "/") + if path == "/" { + return defaultWSPath + } + + return path +} + +// resolveWSPath returns the configured WebSocket path or the default path +// when no override has been provided. +func resolveWSPath(path string) string { + if strings.TrimSpace(path) == "" { + return defaultWSPath + } + return normaliseWSPath(path) +} diff --git a/websocket_test.go b/websocket_test.go index cbad161..5d950b4 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -78,6 +78,93 @@ func TestWSEndpoint_Good(t *testing.T) { } } +func TestWSEndpoint_Good_CustomPath(t *testing.T) { + gin.SetMode(gin.TestMode) + + upgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + } + wsHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Logf("upgrade error: %v", err) + return + } + defer conn.Close() + _ = conn.WriteMessage(websocket.TextMessage, []byte("custom")) + }) + + e, err := api.New(api.WithWSPath("/socket"), api.WithWSHandler(wsHandler)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + srv := httptest.NewServer(e.Handler()) + defer srv.Close() + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "/socket" + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("failed to dial custom WebSocket: %v", err) + } + defer conn.Close() + + _, msg, err := conn.ReadMessage() + if err != nil { + t.Fatalf("failed to read custom WebSocket message: %v", err) + } + if string(msg) != "custom" { + t.Fatalf("expected message=%q, got %q", "custom", string(msg)) + } +} + +func TestWSEndpoint_Good_WithResponseMeta(t *testing.T) { + gin.SetMode(gin.TestMode) + + upgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + } + wsHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Logf("upgrade error: %v", err) + return + } + defer conn.Close() + _ = conn.WriteMessage(websocket.TextMessage, []byte("meta")) + }) + + e, err := api.New( + api.WithRequestID(), + api.WithResponseMeta(), + api.WithWSHandler(wsHandler), + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + srv := httptest.NewServer(e.Handler()) + defer srv.Close() + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "/ws" + conn, resp, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + if resp != nil { + t.Fatalf("failed to dial WebSocket: %v (status=%d)", err, resp.StatusCode) + } + t.Fatalf("failed to dial WebSocket: %v", err) + } + defer conn.Close() + + _, msg, err := conn.ReadMessage() + if err != nil { + t.Fatalf("failed to read message: %v", err) + } + if string(msg) != "meta" { + t.Fatalf("expected message=%q, got %q", "meta", string(msg)) + } +} + func TestNoWSHandler_Good(t *testing.T) { gin.SetMode(gin.TestMode)