diff --git a/cmd/server/handlers.go b/cmd/server/handlers.go index 899b5490..d7e820d2 100644 --- a/cmd/server/handlers.go +++ b/cmd/server/handlers.go @@ -85,12 +85,14 @@ func NewHandlers(deps *HandlerDeps) routes.Handlers { // CTEM Discovery - Network Services, State History & Relationships AssetService: handler.NewAssetServiceHandler(repos.AssetService, repos.Asset, v, log), AssetStateHistory: handler.NewAssetStateHistoryHandler(repos.AssetStateHistory, repos.Asset, v, log), - AssetRelationship: handler.NewAssetRelationshipHandler(svc.AssetRelationship, v, log), + AssetRelationship: handler.NewAssetRelationshipHandler(svc.AssetRelationship, v, log), + RelationshipSuggestion: handler.NewRelationshipSuggestionHandler(svc.RelationshipSuggestion, log), // Vulnerabilities & Exposures Vulnerability: vulnHandler, FindingActivity: handler.NewFindingActivityHandler(svc.FindingActivity, svc.Vulnerability, log), FindingActions: handler.NewFindingActionsHandler(svc.FindingActions, log), + JiraWebhook: handler.NewJiraWebhookHandler(svc.JiraSync, log), Exposure: handler.NewExposureHandler(svc.Exposure, svc.User, v, log), ThreatIntel: handler.NewThreatIntelHandler(svc.ThreatIntel, v, log), CredentialImport: handler.NewCredentialImportHandler(svc.CredentialImport, v, log), diff --git a/cmd/server/repositories.go b/cmd/server/repositories.go index 67bc98dd..a9548c84 100644 --- a/cmd/server/repositories.go +++ b/cmd/server/repositories.go @@ -23,7 +23,8 @@ type Repositories struct { ScopeSchedule *postgres.ScopeScheduleRepository AssetService *postgres.AssetServiceRepository // CTEM: Network services on assets AssetStateHistory *postgres.AssetStateHistoryRepository // CTEM: State change audit log - AssetRelationship *postgres.AssetRelationshipRepository // CTEM: Asset topology graph + AssetRelationship *postgres.AssetRelationshipRepository // CTEM: Asset topology graph + RelationshipSuggestion *postgres.RelationshipSuggestionRepository // CTEM: Relationship suggestions // Vulnerabilities & Findings Vulnerability *postgres.VulnerabilityRepository @@ -168,7 +169,8 @@ func NewRepositories(db *postgres.DB) *Repositories { ScopeSchedule: postgres.NewScopeScheduleRepository(db), AssetService: postgres.NewAssetServiceRepository(db), // CTEM: Network services AssetStateHistory: postgres.NewAssetStateHistoryRepository(db), // CTEM: State change audit - AssetRelationship: postgres.NewAssetRelationshipRepository(db), // CTEM: Asset topology graph + AssetRelationship: postgres.NewAssetRelationshipRepository(db), // CTEM: Asset topology graph + RelationshipSuggestion: postgres.NewRelationshipSuggestionRepository(db), // CTEM: Relationship suggestions // Vulnerabilities & Findings Vulnerability: postgres.NewVulnerabilityRepository(db), diff --git a/cmd/server/services.go b/cmd/server/services.go index 9c8fcc3a..b14f652b 100644 --- a/cmd/server/services.go +++ b/cmd/server/services.go @@ -55,8 +55,9 @@ type Services struct { Asset *app.AssetService AssetGroup *app.AssetGroupService AssetType *app.AssetTypeService - AssetRelationship *app.AssetRelationshipService - Scope *app.ScopeService + AssetRelationship *app.AssetRelationshipService + RelationshipSuggestion *app.RelationshipSuggestionService + Scope *app.ScopeService AttackSurface *app.AttackSurfaceService // Configuration (read-only system config) @@ -159,6 +160,9 @@ type Services struct { APIKey *app.APIKeyService Webhook *app.WebhookService + // Jira Bidirectional Sync + JiraSync *app.JiraSyncService + // AI Triage AITriage *app.AITriageService @@ -224,8 +228,9 @@ func NewServices(deps *ServiceDeps) (*Services, error) { s.AssetGroup = app.NewAssetGroupService(repos.AssetGroup, log) s.AssetType = app.NewAssetTypeService(repos.AssetType, repos.AssetTypeCat, log) s.Scope = app.NewScopeService(repos.ScopeTarget, repos.ScopeExcl, repos.ScopeSchedule, repos.Asset, log) - s.AttackSurface = app.NewAttackSurfaceService(repos.Asset, log) + s.AttackSurface = app.NewAttackSurfaceService(repos.Asset, repos.AssetRelationship, log) s.AssetRelationship = app.NewAssetRelationshipService(repos.AssetRelationship, repos.Asset, log) + s.RelationshipSuggestion = app.NewRelationshipSuggestionService(repos.RelationshipSuggestion, repos.Asset, repos.AssetRelationship, log) // Initialize finding source service (read-only system configuration) s.FindingSource = app.NewFindingSourceService(repos.FindingSource, repos.FindingSourceCat, log) @@ -353,6 +358,7 @@ func NewServices(deps *ServiceDeps) (*Services, error) { // Initialize API Key & Webhook services s.APIKey = app.NewAPIKeyService(repos.APIKey, log) s.Webhook = app.NewWebhookService(repos.Webhook, s.Encryptor, log) + s.JiraSync = app.NewJiraSyncService(repos.Finding, log) // Initialize integration & notification services s.Integration = app.NewIntegrationService(repos.Integration, repos.IntegrationSCMExt, s.Encryptor, log) @@ -462,6 +468,10 @@ func NewServices(deps *ServiceDeps) (*Services, error) { scan.WithProfileRepo(repos.ScanProfile), ) + // Wire verification scan trigger: allows FindingActionsService to launch targeted scans + // when a finding transitions to fix_applied and the user requests scan-based verification. + s.FindingActions.SetVerificationScanTrigger(app.NewVerificationScanTriggerAdapter(s.Scan)) + // Create adapters for pipeline sub-package pipelineAuditAdapter := app.NewPipelineAuditServiceAdapter(s.Audit) pipelineAgentSelectorAdapter := app.NewPipelineAgentSelectorAdapter(s.AgentSelector) diff --git a/cmd/server/workers.go b/cmd/server/workers.go index d31ebffe..be388f7a 100644 --- a/cmd/server/workers.go +++ b/cmd/server/workers.go @@ -212,6 +212,23 @@ func NewWorkers(deps *WorkerDeps) (*Workers, error) { }, )) + // Threat intel — daily EPSS + KEV refresh (fetches + persists to DB) + w.ControllerManager.Register(controller.NewThreatIntelRefreshController( + svc.ThreatIntel, + log.With("controller", "threat-intel-refresh"), + )) + + // Control test scheduler — daily sweep to mark stale detection coverage as overdue + w.ControllerManager.Register(controller.NewControlTestSchedulerController( + repos.ControlTest, + &controller.ControlTestSchedulerConfig{ + Interval: 24 * time.Hour, + StaleDays: 30, + BatchSize: 500, + Logger: log.With("controller", "control-test-scheduler"), + }, + )) + return w, nil } diff --git a/configs/relationship-types.yaml b/configs/relationship-types.yaml index 4bb9a8ae..d346e049 100644 --- a/configs/relationship-types.yaml +++ b/configs/relationship-types.yaml @@ -140,6 +140,9 @@ types: # Code hierarchy - sources: [repository] targets: [container_image] + # DNS hierarchy + - sources: [domain] + targets: [subdomain] - id: exposes category: attack_surface_mapping @@ -161,14 +164,13 @@ types: direct: Resolves To inverse: Resolved By description: >- - Literal DNS A/AAAA resolution — a domain resolves to an IP record - or a load balancer that owns that IP. STRICT semantic: target - MUST be the network endpoint, not the server that happens to own - the IP. For "this domain leads to this server / website" use - Exposes. For subdomain → parent domain or CNAME aliases use - Cname Of. + Literal DNS A/AAAA resolution — a domain or subdomain resolves to + an IP record or a load balancer that owns that IP. STRICT semantic: + target MUST be the network endpoint, not the server that happens + to own the IP. For "this domain leads to this server / website" + use Exposes. For subdomain → parent domain hierarchy use Contains. constraints: - - sources: [domain] + - sources: [domain, subdomain] targets: [ip_address, load_balancer] - id: cname_of @@ -176,12 +178,13 @@ types: direct: CNAME Of inverse: Has CNAME description: >- - DNS aliasing — this name is a CNAME for that name. Also used for - subdomain → parent domain logical relationships. Distinct from - Resolves To which captures the final IP/host record. + DNS CNAME aliasing — this name is a CNAME record pointing to + another name. Strictly for actual DNS CNAME records, NOT for + subdomain hierarchy (use Contains for that). Distinct from + Resolves To which captures the final A/AAAA IP record. constraints: - - sources: [domain] - targets: [domain] + - sources: [domain, subdomain] + targets: [domain, subdomain] # ---- Attack Path Analysis ---------------------------------------------- diff --git a/internal/app/adapters.go b/internal/app/adapters.go index e8d8ede1..cba5c811 100644 --- a/internal/app/adapters.go +++ b/internal/app/adapters.go @@ -320,3 +320,39 @@ func convertToPipelineValidationResult(r *ValidationResult) *pipeline.Validation Errors: errors, } } + +// ============================================================================= +// Verification Scan Trigger Adapter +// ============================================================================= + +// verificationScanTriggerAdapter adapts scan.Service to FindingActionsService.VerificationScanTrigger. +type verificationScanTriggerAdapter struct { + svc *scan.Service +} + +// NewVerificationScanTriggerAdapter creates an adapter that wraps scan.Service +// for use as a VerificationScanTrigger. +func NewVerificationScanTriggerAdapter(svc *scan.Service) VerificationScanTrigger { + return &verificationScanTriggerAdapter{svc: svc} +} + +// TriggerVerificationScan implements VerificationScanTrigger. +func (a *verificationScanTriggerAdapter) TriggerVerificationScan( + ctx context.Context, tenantID, createdBy, scannerName, workflowID string, targets []string, +) (pipelineRunID, scanID string, err error) { + result, err := a.svc.QuickScan(ctx, scan.QuickScanInput{ + TenantID: tenantID, + Targets: targets, + ScannerName: scannerName, + WorkflowID: workflowID, + CreatedBy: createdBy, + Tags: []string{"verification-scan"}, + Config: map[string]any{ + "trigger": "verification_scan", + }, + }) + if err != nil { + return "", "", err + } + return result.PipelineRunID, result.ScanID, nil +} diff --git a/internal/app/asset_service.go b/internal/app/asset_service.go index c5e04504..7f62a705 100644 --- a/internal/app/asset_service.go +++ b/internal/app/asset_service.go @@ -397,9 +397,186 @@ func PromoteKnownProperties(input CreateAssetInput) CreateAssetInput { delete(input.Properties, key) } + // Normalize ALL camelCase property keys to snake_case. + // Collectors may send either convention; we standardize on snake_case. + // Generic converter handles any camelCase key automatically. + normalizedProps := make(map[string]any, len(input.Properties)) + for key, val := range input.Properties { + snakeKey := camelToSnakeCase(key) + // If both camelCase and snake_case exist, prefer the snake_case value + if snakeKey != key { + if _, exists := normalizedProps[snakeKey]; exists { + continue // snake_case version already set, skip camelCase duplicate + } + } + normalizedProps[snakeKey] = val + } + input.Properties = normalizedProps + + // Extract DNS fields from nested domain.dns_records → flat properties + // Collector sends: {"domain": {"dns_records": [{"type":"A","value":"1.2.3.4","ttl":300}]}} + // UI reads flat: record_type, resolved_ip, cname_target, ttl, dns_record_types, resolved_ips + if domainObj, ok := input.Properties["domain"].(map[string]any); ok { + if records, ok := domainObj["dns_records"].([]any); ok && len(records) > 0 { + var recordTypes []string + var resolvedIPs []string + for _, r := range records { + rec, ok := r.(map[string]any) + if !ok { + continue + } + recType, _ := rec["type"].(string) + recValue, _ := rec["value"].(string) + if recType != "" { + recordTypes = append(recordTypes, recType) + } + if recValue != "" && (recType == "A" || recType == "AAAA") { + resolvedIPs = append(resolvedIPs, recValue) + } + } + // First record as primary + if first, ok := records[0].(map[string]any); ok { + if rt, _ := first["type"].(string); rt != "" { + input.Properties["record_type"] = rt + } + if rv, _ := first["value"].(string); rv != "" { + rt, _ := first["type"].(string) + if rt == "A" || rt == "AAAA" { + input.Properties["resolved_ip"] = rv + } else if rt == "CNAME" { + input.Properties["cname_target"] = rv + } + } + if ttl, ok := first["ttl"]; ok { + input.Properties["ttl"] = ttl + } + } + // Aggregates + if len(recordTypes) > 0 { + input.Properties["dns_record_types"] = strings.Join(unique(recordTypes), ", ") + } + if len(resolvedIPs) > 0 { + input.Properties["resolved_ips"] = strings.Join(unique(resolvedIPs), ", ") + } + input.Properties["dns_record_count"] = len(records) + } + } + + // Normalize root_domain (strip trailing dot) + if rd, ok := input.Properties["root_domain"].(string); ok && strings.HasSuffix(rd, ".") { + input.Properties["root_domain"] = strings.TrimSuffix(rd, ".") + } + + // Auto-detect subdomain: if type is "domain", check whether the name + // looks like a subdomain based on domain level analysis. + // Method 1: use root_domain property if provided by collector + // Method 2: compute from domain name structure (handles .com.vn, .co.uk, etc.) + if input.Type == "domain" { + cleanName := strings.TrimSuffix(input.Name, ".") + + // Method 1: collector provides root_domain + if rootDomain, ok := input.Properties["root_domain"].(string); ok && rootDomain != "" { + cleanRoot := strings.TrimSuffix(rootDomain, ".") + if cleanRoot != cleanName && strings.HasSuffix(cleanName, "."+cleanRoot) { + input.Type = "subdomain" + } + } + + // Method 2: compute domain level from name structure + // A root domain has exactly 1 label before the effective TLD + // e.g., "ipa.com.vn" = root (1 label "ipa" before "com.vn") + // "sub.ipa.com.vn" = subdomain (2 labels before "com.vn") + if input.Type == "domain" && isLikelySubdomain(cleanName) { + input.Type = "subdomain" + } + } + return input } +// camelToSnakeCase converts a camelCase or PascalCase string to snake_case. +// Examples: "cpuCores" → "cpu_cores", "memoryGB" → "memory_gb", "apiType" → "api_type" +// Already snake_case or lowercase strings pass through unchanged. +func camelToSnakeCase(s string) string { + if s == "" { + return s + } + var result []byte + for i, r := range s { + if r >= 'A' && r <= 'Z' { + // Insert underscore before uppercase if: + // - not the first character + // - AND (previous char is lowercase OR next char is lowercase) + // This handles: "memoryGB" → "memory_gb", "apiURL" → "api_url" + if i > 0 { + prev := s[i-1] + if prev >= 'a' && prev <= 'z' { + result = append(result, '_') + } else if prev >= 'A' && prev <= 'Z' && i+1 < len(s) && s[i+1] >= 'a' && s[i+1] <= 'z' { + result = append(result, '_') + } + } + result = append(result, byte(r-'A'+'a')) + } else { + result = append(result, byte(r)) + } + } + return string(result) +} + +// unique returns a deduplicated copy of a string slice, preserving order. +func unique(ss []string) []string { + seen := make(map[string]bool, len(ss)) + out := make([]string, 0, len(ss)) + for _, s := range ss { + if !seen[s] { + seen[s] = true + out = append(out, s) + } + } + return out +} + +// isLikelySubdomain checks if a domain name has more labels than a typical root domain. +// Uses known second-level domains (.com.vn, .co.uk, .com.au, etc.) to determine +// the effective TLD length. If there are >1 labels before the eTLD, it's a subdomain. +func isLikelySubdomain(name string) bool { + parts := strings.Split(name, ".") + if len(parts) < 3 { + return false // "example.com" = 2 parts, not a subdomain + } + + // Known second-level TLDs (eTLD+1 has 3+ parts for root domains) + knownSLDs := map[string]bool{ + "com.vn": true, "net.vn": true, "org.vn": true, "edu.vn": true, "gov.vn": true, + "co.uk": true, "org.uk": true, "ac.uk": true, + "com.au": true, "net.au": true, "org.au": true, + "co.jp": true, "or.jp": true, "ac.jp": true, + "co.kr": true, "or.kr": true, + "com.br": true, "org.br": true, + "co.in": true, "org.in": true, "net.in": true, + "com.sg": true, "org.sg": true, + "com.my": true, "org.my": true, + "co.th": true, "or.th": true, + "com.tw": true, "org.tw": true, + "co.id": true, "or.id": true, + "com.ph": true, "org.ph": true, + "co.nz": true, "org.nz": true, + "co.za": true, "org.za": true, + } + + // Check if last 2 parts form a known SLD + last2 := strings.Join(parts[len(parts)-2:], ".") + if knownSLDs[last2] { + // For .com.vn: "ipa.com.vn" = 3 parts = root; "sub.ipa.com.vn" = 4 parts = subdomain + return len(parts) > 3 + } + + // For simple TLDs (.com, .net, .org, .io, etc.): + // "example.com" = 2 parts = root; "sub.example.com" = 3 parts = subdomain + return len(parts) > 2 +} + // correlateByIPOrHostname tries to find an existing asset by IP or hostname properties. // If input.Name looks like an IP (e.g., "10.0.1.5"), search for hosts with that IP in properties. // If input.Name looks like a hostname, search for IP-named assets with that hostname in properties. @@ -937,15 +1114,15 @@ func (s *AssetService) GetPropertyFacets(ctx context.Context, tenantID string, t // Supports prefix filtering for autocomplete. // GetAssetStats returns aggregated asset statistics using SQL aggregation. // Filters: types (asset_type ANY), tags (overlap, matches List semantics). -func (s *AssetService) GetAssetStats(ctx context.Context, tenantID string, types []string, tags []string, subType string) (*asset.AggregateStats, error) { +func (s *AssetService) GetAssetStats(ctx context.Context, tenantID string, types []string, tags []string, subType string, countByFields ...string) (*asset.AggregateStats, error) { parsedTenantID, err := shared.IDFromString(tenantID) if err != nil { return nil, fmt.Errorf("%w: invalid tenant id format", shared.ErrValidation) } - return s.repo.GetAggregateStats(ctx, parsedTenantID, types, tags, subType) + return s.repo.GetAggregateStats(ctx, parsedTenantID, types, tags, subType, countByFields...) } -func (s *AssetService) ListTags(ctx context.Context, tenantID string, prefix string, limit int) ([]string, error) { +func (s *AssetService) ListTags(ctx context.Context, tenantID string, prefix string, types []string, limit int) ([]string, error) { parsedTenantID, err := shared.IDFromString(tenantID) if err != nil { return nil, fmt.Errorf("%w: invalid tenant id format", shared.ErrValidation) @@ -961,7 +1138,7 @@ func (s *AssetService) ListTags(ctx context.Context, tenantID string, prefix str prefix = prefix[:50] } - return s.repo.ListDistinctTags(ctx, parsedTenantID, prefix, limit) + return s.repo.ListDistinctTags(ctx, parsedTenantID, prefix, types, limit) } // ActivateAsset activates an asset. diff --git a/internal/app/attack_path_scoring.go b/internal/app/attack_path_scoring.go new file mode 100644 index 00000000..6bab868e --- /dev/null +++ b/internal/app/attack_path_scoring.go @@ -0,0 +1,284 @@ +package app + +import ( + "context" + "fmt" + "sort" + + "github.com/openctemio/api/pkg/domain/asset" + "github.com/openctemio/api/pkg/domain/shared" +) + +// attackPathRelationshipTypes are the relationship types that represent +// lateral movement and attack progression. We traverse ONLY these types +// when computing reachability from public entry points. +var attackPathRelationshipTypes = map[asset.RelationshipType]bool{ + asset.RelTypeRunsOn: true, + asset.RelTypeDeployedTo: true, + asset.RelTypeContains: true, + asset.RelTypeExposes: true, + asset.RelTypeResolvesTo: true, + asset.RelTypeDependsOn: true, + asset.RelTypeSendsDataTo: true, + asset.RelTypeStoresDataIn: true, + asset.RelTypeAuthenticatesTo: true, + asset.RelTypeGrantedTo: true, + asset.RelTypeHasAccessTo: true, + asset.RelTypeLoadBalances: true, +} + +// controlRelationshipTypes are relationships that indicate security controls. +// Assets protected by these add a "protected" flag to scored nodes. +var controlRelationshipTypes = map[asset.RelationshipType]bool{ + asset.RelTypeProtectedBy: true, + asset.RelTypeMonitors: true, +} + +// AssetPathScore holds the computed attack path score for a single asset. +type AssetPathScore struct { + // AssetID is the UUID of the asset. + AssetID string `json:"asset_id"` + // Name is the human-readable name of the asset. + Name string `json:"name"` + // AssetType is the type of the asset (e.g., "host", "application"). + AssetType string `json:"asset_type"` + // Exposure is the asset's configured exposure level. + Exposure string `json:"exposure"` + // Criticality is the asset's criticality level. + Criticality string `json:"criticality"` + // RiskScore is the asset-level risk score (1-10). + RiskScore int `json:"risk_score"` + // IsCrownJewel marks high-value target assets. + IsCrownJewel bool `json:"is_crown_jewel"` + // FindingCount is the number of open findings on this asset. + FindingCount int `json:"finding_count"` + + // ReachableFrom is the count of distinct public entry points that can + // reach this asset following attack-path relationship types. + ReachableFrom int `json:"reachable_from"` + // PathScore is the composite attack path score: + // (reachable_from * impact_weight) where impact_weight = risk_score * criticality_multiplier + // Higher = more urgent to remediate. + PathScore float64 `json:"path_score"` + // IsEntryPoint is true when this asset is itself a public entry point. + IsEntryPoint bool `json:"is_entry_point"` + // IsProtected is true when the asset has at least one "protected_by" or "monitors" relationship. + IsProtected bool `json:"is_protected"` +} + +// AttackPathSummary holds aggregate attack path metrics for the tenant. +type AttackPathSummary struct { + // TotalPaths is the total number of directed attack paths discovered + // (entry point → reachable asset pairs). + TotalPaths int `json:"total_paths"` + // EntryPoints is the count of public-exposure assets that act as entry points. + EntryPoints int `json:"entry_points"` + // ReachableAssets is the count of non-public assets reachable from at least one entry point. + ReachableAssets int `json:"reachable_assets"` + // MaxDepth is the longest BFS chain found. + MaxDepth int `json:"max_depth"` + // CriticalReachable is the count of critical/high assets reachable from entry points. + CriticalReachable int `json:"critical_reachable"` + // CrownJewelsAtRisk is the count of crown-jewel assets reachable from entry points. + CrownJewelsAtRisk int `json:"crown_jewels_at_risk"` + // HasRelationshipData indicates whether the tenant has any relationship data at all. + HasRelationshipData bool `json:"has_relationship_data"` +} + +// AttackPathScoringResult is the full result returned by ComputeAttackPathScores. +type AttackPathScoringResult struct { + Summary AttackPathSummary `json:"summary"` + // TopAssets is the ranked list of assets by PathScore (descending), limited to 50. + TopAssets []AssetPathScore `json:"top_assets"` +} + +// criticalityMultiplier returns an impact multiplier for a criticality level. +func criticalityMultiplier(criticality string) float64 { + switch criticality { + case "critical": + return 4.0 + case "high": + return 3.0 + case "medium": + return 2.0 + case "low": + return 1.0 + default: + return 1.0 + } +} + +// ComputeAttackPathScores performs in-memory attack path analysis for the +// tenant. It: +// 1. Loads all assets (nodes) and relationships (edges) +// 2. Identifies public-exposure assets as entry points +// 3. Runs BFS from each entry point following attack-path edges +// 4. Counts how many entry points can reach each internal asset +// 5. Computes a composite PathScore from reachability + risk + criticality +// 6. Returns top 50 assets by PathScore + aggregate summary +func (s *AttackSurfaceService) ComputeAttackPathScores( + ctx context.Context, + tenantID shared.ID, + relRepo asset.RelationshipRepository, +) (*AttackPathScoringResult, error) { + // Step 1 — load nodes + nodes, err := s.assetRepo.ListAllNodes(ctx, tenantID) + if err != nil { + return nil, fmt.Errorf("load nodes: %w", err) + } + + // Step 2 — load edges + edges, err := relRepo.ListAllEdges(ctx, tenantID) + if err != nil { + return nil, fmt.Errorf("load edges: %w", err) + } + + // Build lookup maps + nodeByID := make(map[string]*asset.AssetNode, len(nodes)) + for i := range nodes { + nodeByID[nodes[i].ID] = &nodes[i] + } + + // Build adjacency list (directed: source → targets) for attack-path edges + adj := make(map[string][]string, len(nodes)) + // Track which assets are "protected" + protected := make(map[string]bool) + + for _, e := range edges { + if controlRelationshipTypes[e.Type] { + // source asset is protected (protected_by / monitors points FROM protected TO control) + // convention: "A protected_by B" means A is target, B is source when edge direction + // is stored as source=A → target=B? Let's treat target as the protected asset. + protected[e.TargetAssetID] = true + continue + } + if attackPathRelationshipTypes[e.Type] { + adj[e.SourceAssetID] = append(adj[e.SourceAssetID], e.TargetAssetID) + } + } + + // Step 3 — identify public entry points + entryPoints := make([]string, 0) + for _, n := range nodes { + if n.Exposure == "public" { + entryPoints = append(entryPoints, n.ID) + } + } + + // Step 4 — BFS from each entry point, counting reachability per node + reachableFrom := make(map[string]int, len(nodes)) // assetID → count of entry points that can reach it + maxDepth := 0 + + for _, ep := range entryPoints { + visited := make(map[string]bool) + visited[ep] = true + queue := []struct { + id string + depth int + }{{ep, 0}} + + for len(queue) > 0 { + cur := queue[0] + queue = queue[1:] + + if cur.depth > maxDepth { + maxDepth = cur.depth + } + + for _, neighborID := range adj[cur.id] { + if visited[neighborID] { + continue + } + visited[neighborID] = true + // Only count non-public assets as "reachable internal assets" + if n, ok := nodeByID[neighborID]; ok && n.Exposure != "public" { + reachableFrom[neighborID]++ + } + queue = append(queue, struct { + id string + depth int + }{neighborID, cur.depth + 1}) + } + } + } + + // Step 5 — compute scores and build result + totalPaths := 0 + reachableSet := make(map[string]bool) + criticalReachable := 0 + crownJewelsAtRisk := 0 + + scored := make([]AssetPathScore, 0, len(nodes)) + for _, n := range nodes { + rc := reachableFrom[n.ID] + isEntry := n.Exposure == "public" + + if rc > 0 { + reachableSet[n.ID] = true + totalPaths += rc + if n.Criticality == "critical" || n.Criticality == "high" { + criticalReachable++ + } + if n.IsCrownJewel { + crownJewelsAtRisk++ + } + } + + riskScore := n.RiskScore + if riskScore == 0 { + riskScore = 5 // default mid-range when not set + } + + pathScore := float64(rc) * float64(riskScore) * criticalityMultiplier(n.Criticality) + if n.FindingCount > 0 { + // Boost path score proportional to number of open findings + pathScore += float64(n.FindingCount) * 0.5 + } + + scored = append(scored, AssetPathScore{ + AssetID: n.ID, + Name: n.Name, + AssetType: n.AssetType, + Exposure: n.Exposure, + Criticality: n.Criticality, + RiskScore: n.RiskScore, + IsCrownJewel: n.IsCrownJewel, + FindingCount: n.FindingCount, + ReachableFrom: rc, + PathScore: pathScore, + IsEntryPoint: isEntry, + IsProtected: protected[n.ID], + }) + } + + // Sort by PathScore descending, then by ReachableFrom, then by name for stability + sort.Slice(scored, func(i, j int) bool { + if scored[i].PathScore != scored[j].PathScore { + return scored[i].PathScore > scored[j].PathScore + } + if scored[i].ReachableFrom != scored[j].ReachableFrom { + return scored[i].ReachableFrom > scored[j].ReachableFrom + } + return scored[i].Name < scored[j].Name + }) + + // Cap at 50 + if len(scored) > 50 { + scored = scored[:50] + } + + summary := AttackPathSummary{ + TotalPaths: totalPaths, + EntryPoints: len(entryPoints), + ReachableAssets: len(reachableSet), + MaxDepth: maxDepth, + CriticalReachable: criticalReachable, + CrownJewelsAtRisk: crownJewelsAtRisk, + HasRelationshipData: len(edges) > 0, + } + + return &AttackPathScoringResult{ + Summary: summary, + TopAssets: scored, + }, nil +} diff --git a/internal/app/attack_surface_service.go b/internal/app/attack_surface_service.go index 023362a8..9006630d 100644 --- a/internal/app/attack_surface_service.go +++ b/internal/app/attack_surface_service.go @@ -88,17 +88,24 @@ type AttackSurfaceStatsData struct { // AttackSurfaceService provides attack surface operations. type AttackSurfaceService struct { assetRepo asset.Repository + relRepo asset.RelationshipRepository logger *logger.Logger } // NewAttackSurfaceService creates a new AttackSurfaceService. -func NewAttackSurfaceService(assetRepo asset.Repository, log *logger.Logger) *AttackSurfaceService { +func NewAttackSurfaceService(assetRepo asset.Repository, relRepo asset.RelationshipRepository, log *logger.Logger) *AttackSurfaceService { return &AttackSurfaceService{ assetRepo: assetRepo, + relRepo: relRepo, logger: log.With("service", "attack_surface"), } } +// GetAttackPathScores computes attack path scoring for the tenant. +func (s *AttackSurfaceService) GetAttackPathScores(ctx context.Context, tenantID shared.ID) (*AttackPathScoringResult, error) { + return s.ComputeAttackPathScores(ctx, tenantID, s.relRepo) +} + // GetStats returns attack surface statistics for a tenant. func (s *AttackSurfaceService) GetStats(ctx context.Context, tenantID shared.ID) (*AttackSurfaceStats, error) { tenantIDStr := tenantID.String() diff --git a/internal/app/business_unit_service.go b/internal/app/business_unit_service.go index e9ada59d..222ae2f8 100644 --- a/internal/app/business_unit_service.go +++ b/internal/app/business_unit_service.go @@ -63,6 +63,39 @@ func (s *BusinessUnitService) List(ctx context.Context, tenantID string, filter return s.repo.List(ctx, filter, page) } +// UpdateBusinessUnitInput holds input for updating a BU. +type UpdateBusinessUnitInput struct { + TenantID string + ID string + Name string + Description string + OwnerName string + OwnerEmail string + Tags []string +} + +// Update updates an existing business unit. +func (s *BusinessUnitService) Update(ctx context.Context, input UpdateBusinessUnitInput) (*businessunit.BusinessUnit, error) { + tid, err := shared.IDFromString(input.TenantID) + if err != nil { + return nil, fmt.Errorf("%w: invalid tenant id", shared.ErrValidation) + } + bid, err := shared.IDFromString(input.ID) + if err != nil { + return nil, fmt.Errorf("%w: invalid business unit id", shared.ErrValidation) + } + bu, err := s.repo.GetByID(ctx, tid, bid) + if err != nil { + return nil, fmt.Errorf("failed to get business unit: %w", err) + } + bu.Update(input.Name, input.Description, input.OwnerName, input.OwnerEmail) + bu.SetTags(input.Tags) + if err := s.repo.Update(ctx, bu); err != nil { + return nil, fmt.Errorf("failed to update business unit: %w", err) + } + return bu, nil +} + // Delete deletes a BU. func (s *BusinessUnitService) Delete(ctx context.Context, tenantID, buID string) error { tid, _ := shared.IDFromString(tenantID) diff --git a/internal/app/finding_actions_service.go b/internal/app/finding_actions_service.go index ad826dec..48f493ed 100644 --- a/internal/app/finding_actions_service.go +++ b/internal/app/finding_actions_service.go @@ -15,6 +15,15 @@ import ( "github.com/openctemio/api/pkg/pagination" ) +// VerificationScanTrigger is the interface for triggering targeted verification scans. +// Implemented by the scan.Service; kept as interface to avoid import cycles. +type VerificationScanTrigger interface { + // TriggerVerificationScan launches a quick scan on the given targets. + // targets is a list of asset identifiers (names / hostnames / URLs). + // Returns the pipeline run ID and scan ID on success. + TriggerVerificationScan(ctx context.Context, tenantID, createdBy, scannerName, workflowID string, targets []string) (pipelineRunID, scanID string, err error) +} + // FindingActionsService handles the closed-loop finding lifecycle: // in_progress → fix_applied → resolved (verified by scan or security). type FindingActionsService struct { @@ -23,6 +32,7 @@ type FindingActionsService struct { groupRepo group.Repository assetRepo asset.Repository activityService *FindingActivityService + scanTrigger VerificationScanTrigger // optional; set via SetVerificationScanTrigger db *sql.DB logger *logger.Logger } @@ -48,6 +58,11 @@ func NewFindingActionsService( } } +// SetVerificationScanTrigger wires the scan trigger (called after both services are initialized). +func (s *FindingActionsService) SetVerificationScanTrigger(trigger VerificationScanTrigger) { + s.scanTrigger = trigger +} + // --- Group View --- // ListFindingGroups returns findings grouped by a dimension. @@ -556,6 +571,92 @@ func (s *FindingActionsService) AutoAssignToOwners( return result, nil } +// --- Verification Scan --- + +// RequestVerificationScanInput is the input for requesting a verification scan. +type RequestVerificationScanInput struct { + FindingID string + ScannerName string // required if WorkflowID is empty + WorkflowID string // required if ScannerName is empty +} + +// RequestVerificationScanResult is the result of requesting a verification scan. +type RequestVerificationScanResult struct { + FindingID string `json:"finding_id"` + AssetID string `json:"asset_id"` + AssetName string `json:"asset_name"` + PipelineRunID string `json:"pipeline_run_id"` + ScanID string `json:"scan_id"` +} + +// RequestVerificationScan triggers a targeted quick scan on the asset associated with a finding. +// The finding must be in fix_applied status (dev has marked it as fixed; awaiting scan verification). +// The scan result is expected to either confirm the fix (→ resolved) or reveal the vuln still exists +// (→ back to in_progress) via the normal ingest pipeline. +func (s *FindingActionsService) RequestVerificationScan( + ctx context.Context, tenantID, userID string, input RequestVerificationScanInput, +) (*RequestVerificationScanResult, error) { + if s.scanTrigger == nil { + return nil, fmt.Errorf("%w: verification scan trigger not configured", shared.ErrInternal) + } + + if input.ScannerName == "" && input.WorkflowID == "" { + return nil, fmt.Errorf("%w: scanner_name or workflow_id is required", shared.ErrValidation) + } + + tid, err := shared.IDFromString(tenantID) + if err != nil { + return nil, fmt.Errorf("%w: invalid tenant_id", shared.ErrValidation) + } + + fid, err := shared.IDFromString(input.FindingID) + if err != nil { + return nil, fmt.Errorf("%w: invalid finding_id", shared.ErrValidation) + } + + f, err := s.findingRepo.GetByID(ctx, tid, fid) + if err != nil { + return nil, fmt.Errorf("finding not found: %w", err) + } + + if f.Status() != vulnerability.FindingStatusFixApplied { + return nil, fmt.Errorf( + "%w: finding must be in fix_applied status to request verification scan (current: %s)", + shared.ErrValidation, f.Status(), + ) + } + + assetEntity, err := s.assetRepo.GetByID(ctx, tid, f.AssetID()) + if err != nil { + return nil, fmt.Errorf("asset not found for finding: %w", err) + } + + targets := []string{assetEntity.Name()} + + runID, scanID, err := s.scanTrigger.TriggerVerificationScan( + ctx, tenantID, userID, input.ScannerName, input.WorkflowID, targets, + ) + if err != nil { + return nil, fmt.Errorf("failed to trigger verification scan: %w", err) + } + + s.logger.Info("verification scan triggered", + "finding_id", f.ID(), + "asset_id", f.AssetID(), + "asset_name", assetEntity.Name(), + "pipeline_run_id", runID, + "scan_id", scanID, + ) + + return &RequestVerificationScanResult{ + FindingID: f.ID().String(), + AssetID: f.AssetID().String(), + AssetName: assetEntity.Name(), + PipelineRunID: runID, + ScanID: scanID, + }, nil +} + // --- Validation helpers --- var cveIDRegex = regexp.MustCompile(`^CVE-\d{4}-\d{4,}$`) diff --git a/internal/app/jira_sync_service.go b/internal/app/jira_sync_service.go new file mode 100644 index 00000000..ad0f3ad7 --- /dev/null +++ b/internal/app/jira_sync_service.go @@ -0,0 +1,267 @@ +package app + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + "github.com/openctemio/api/pkg/domain/shared" + "github.com/openctemio/api/pkg/domain/vulnerability" + "github.com/openctemio/api/pkg/logger" +) + +// JiraSyncService handles bidirectional sync between findings and Jira tickets. +// +// Lifecycle: +// 1. POST /findings/{id}/link-ticket — records a Jira ticket key + URL on the finding. +// 2. POST /webhooks/incoming/jira — receives Jira status-change webhooks and updates +// the finding's status accordingly (in_progress / fix_applied / resolved). +// +// The integration is intentionally lightweight: no OAuth, no REST API calls from us. +// Jira pushes status changes to us; users link tickets manually (or via workflow). +type JiraSyncService struct { + findingRepo vulnerability.FindingRepository + logger *logger.Logger +} + +// NewJiraSyncService creates a new JiraSyncService. +func NewJiraSyncService(findingRepo vulnerability.FindingRepository, log *logger.Logger) *JiraSyncService { + return &JiraSyncService{ + findingRepo: findingRepo, + logger: log.With("service", "jira-sync"), + } +} + +// LinkTicketInput is the payload for linking a Jira ticket to a finding. +type LinkTicketInput struct { + TenantID string `json:"tenant_id"` + FindingID string `json:"finding_id"` + TicketKey string `json:"ticket_key"` // e.g. "PROJ-123" + TicketURL string `json:"ticket_url"` // e.g. "https://myorg.atlassian.net/browse/PROJ-123" +} + +// LinkTicket adds a Jira ticket reference to a finding's work_item_uris. +// Idempotent — re-adding the same URL is a no-op at the domain level. +func (s *JiraSyncService) LinkTicket(ctx context.Context, input LinkTicketInput) error { + if strings.TrimSpace(input.TicketKey) == "" { + return fmt.Errorf("%w: ticket_key is required", shared.ErrValidation) + } + if strings.TrimSpace(input.TicketURL) == "" { + return fmt.Errorf("%w: ticket_url is required", shared.ErrValidation) + } + + tenantID, err := shared.IDFromString(input.TenantID) + if err != nil { + return fmt.Errorf("%w: invalid tenant ID", shared.ErrValidation) + } + + findingID, err := shared.IDFromString(input.FindingID) + if err != nil { + return fmt.Errorf("%w: invalid finding ID", shared.ErrValidation) + } + + finding, err := s.findingRepo.GetByID(ctx, tenantID, findingID) + if err != nil { + return fmt.Errorf("get finding: %w", err) + } + + // Add only if not already present (domain method is idempotent). + finding.AddWorkItemURI(input.TicketURL) + + if err := s.findingRepo.UpdateWorkItemURIs(ctx, tenantID, findingID, finding.WorkItemURIs()); err != nil { + return fmt.Errorf("persist ticket link: %w", err) + } + + s.logger.Info("jira ticket linked to finding", + "finding_id", findingID.String(), + "ticket_key", input.TicketKey, + "ticket_url", input.TicketURL, + ) + return nil +} + +// UnlinkTicket removes a Jira ticket reference from a finding. +func (s *JiraSyncService) UnlinkTicket(ctx context.Context, tenantIDStr, findingIDStr, ticketURL string) error { + tenantID, err := shared.IDFromString(tenantIDStr) + if err != nil { + return fmt.Errorf("%w: invalid tenant ID", shared.ErrValidation) + } + findingID, err := shared.IDFromString(findingIDStr) + if err != nil { + return fmt.Errorf("%w: invalid finding ID", shared.ErrValidation) + } + + finding, err := s.findingRepo.GetByID(ctx, tenantID, findingID) + if err != nil { + return fmt.Errorf("get finding: %w", err) + } + + existing := finding.WorkItemURIs() + updated := make([]string, 0, len(existing)) + for _, u := range existing { + if u != ticketURL { + updated = append(updated, u) + } + } + + if err := s.findingRepo.UpdateWorkItemURIs(ctx, tenantID, findingID, updated); err != nil { + return fmt.Errorf("persist ticket unlink: %w", err) + } + + s.logger.Info("jira ticket unlinked from finding", + "finding_id", findingID.String(), + "ticket_url", ticketURL, + ) + return nil +} + +// JiraWebhookPayload is the envelope sent by Jira issue-updated webhooks. +// See https://developer.atlassian.com/cloud/jira/platform/webhooks/ +type JiraWebhookPayload struct { + WebhookEvent string `json:"webhookEvent"` // "jira:issue_updated" + Issue JiraWebhookIssue `json:"issue"` + Changelog *JiraChangelog `json:"changelog,omitempty"` +} + +// JiraWebhookIssue represents the issue block inside a Jira webhook payload. +type JiraWebhookIssue struct { + Key string `json:"key"` // e.g. "PROJ-123" + Self string `json:"self"` // e.g. "https://myorg.atlassian.net/rest/api/2/issue/10001" + Fields map[string]interface{} `json:"fields"` +} + +// JiraChangelog carries the before/after values of changed fields. +type JiraChangelog struct { + Items []JiraChangeItem `json:"items"` +} + +// JiraChangeItem is one entry in the changelog. +type JiraChangeItem struct { + Field string `json:"field"` + FromString string `json:"fromString"` + ToString string `json:"toString"` +} + +// HandleJiraWebhook processes an inbound Jira webhook and syncs finding status. +// +// Status mapping: +// +// Jira "In Progress" → finding "in_progress" +// Jira "In Review" → finding "in_progress" +// Jira "Done" → finding "fix_applied" (triggers verification flow) +// Jira "Resolved" → finding "fix_applied" +// Jira "Closed" → finding "fix_applied" +func (s *JiraSyncService) HandleJiraWebhook(ctx context.Context, tenantID shared.ID, payload JiraWebhookPayload) error { + if payload.Changelog == nil { + // No changes — nothing to sync. + return nil + } + + // Find the status transition in the changelog. + newJiraStatus := "" + for _, item := range payload.Changelog.Items { + if strings.EqualFold(item.Field, "status") { + newJiraStatus = item.ToString + break + } + } + if newJiraStatus == "" { + // Webhook is for a non-status change (field update, comment, etc.) — ignore. + return nil + } + + newFindingStatus, ok := mapJiraStatusToFinding(newJiraStatus) + if !ok { + s.logger.Debug("jira status has no finding mapping — ignored", + "jira_status", newJiraStatus, + "issue_key", payload.Issue.Key, + ) + return nil + } + + // Derive the ticket URL from the issue self link or fallback to Atlassian browse URL. + ticketURL := deriveJiraTicketURL(payload.Issue) + + // Look up finding by work item URI. + finding, err := s.findingRepo.GetByWorkItemURI(ctx, tenantID, ticketURL) + if err != nil { + if errors.Is(err, shared.ErrNotFound) { + // No finding linked to this ticket — silently ignore. + s.logger.Debug("no finding linked to jira ticket", + "ticket_url", ticketURL, + "issue_key", payload.Issue.Key, + ) + return nil + } + return fmt.Errorf("lookup finding by work item URI: %w", err) + } + + // Apply the status transition if valid. + if err := finding.TransitionStatus(newFindingStatus, "", nil); err != nil { + s.logger.Warn("jira webhook: finding status transition not allowed", + "finding_id", finding.ID().String(), + "current_status", finding.Status(), + "target_status", newFindingStatus, + "jira_status", newJiraStatus, + "error", err, + ) + // Not a hard error — the transition might be blocked (e.g., false_positive). + return nil + } + + if err := s.findingRepo.Update(ctx, finding); err != nil { + return fmt.Errorf("update finding status from jira webhook: %w", err) + } + + s.logger.Info("jira webhook synced finding status", + "finding_id", finding.ID().String(), + "issue_key", payload.Issue.Key, + "jira_status", newJiraStatus, + "finding_status", newFindingStatus, + ) + return nil +} + +// mapJiraStatusToFinding maps a Jira status name to a FindingStatus. +// Returns (status, true) when a mapping exists, (_, false) otherwise. +func mapJiraStatusToFinding(jiraStatus string) (vulnerability.FindingStatus, bool) { + normalized := strings.ToLower(strings.TrimSpace(jiraStatus)) + switch normalized { + case "in progress", "in review", "in development", "open": + return vulnerability.FindingStatusInProgress, true + case "done", "resolved", "closed", "completed", "fixed": + return vulnerability.FindingStatusFixApplied, true + case "to do", "backlog", "reopened": + return vulnerability.FindingStatusConfirmed, true + default: + return "", false + } +} + +// deriveJiraTicketURL builds the canonical browse URL for a Jira issue. +// It prefers payload.Issue.Self (REST API URL) but converts it to the browse URL +// so it matches what users paste when linking tickets. +func deriveJiraTicketURL(issue JiraWebhookIssue) string { + if issue.Self != "" { + // Convert REST API URL to browse URL: + // https://myorg.atlassian.net/rest/api/2/issue/10001 → + // https://myorg.atlassian.net/browse/PROJ-123 + // + // Split on "/rest/" and reconstruct. + if idx := strings.Index(issue.Self, "/rest/"); idx > 0 && issue.Key != "" { + base := issue.Self[:idx] + return base + "/browse/" + issue.Key + } + } + return issue.Key +} + +// JiraTicketInfo is returned by LinkTicket to the HTTP handler for the response body. +type JiraTicketInfo struct { + FindingID string `json:"finding_id"` + TicketKey string `json:"ticket_key"` + TicketURL string `json:"ticket_url"` + LinkedAt time.Time `json:"linked_at"` +} diff --git a/internal/app/relationship_suggestion_service.go b/internal/app/relationship_suggestion_service.go new file mode 100644 index 00000000..070c6391 --- /dev/null +++ b/internal/app/relationship_suggestion_service.go @@ -0,0 +1,441 @@ +package app + +import ( + "context" + "fmt" + "strings" + + "github.com/openctemio/api/pkg/domain/asset" + "github.com/openctemio/api/pkg/domain/relationship" + "github.com/openctemio/api/pkg/domain/shared" + "github.com/openctemio/api/pkg/logger" + "github.com/openctemio/api/pkg/pagination" +) + +// RelationshipSuggestionService handles relationship suggestion business logic. +type RelationshipSuggestionService struct { + suggestionRepo relationship.SuggestionRepository + assetRepo asset.Repository + relRepo asset.RelationshipRepository + logger *logger.Logger +} + +// NewRelationshipSuggestionService creates a new RelationshipSuggestionService. +func NewRelationshipSuggestionService( + suggestionRepo relationship.SuggestionRepository, + assetRepo asset.Repository, + relRepo asset.RelationshipRepository, + log *logger.Logger, +) *RelationshipSuggestionService { + return &RelationshipSuggestionService{ + suggestionRepo: suggestionRepo, + assetRepo: assetRepo, + relRepo: relRepo, + logger: log.With("service", "relationship_suggestion"), + } +} + +// GenerateSuggestions analyzes assets and generates relationship suggestions. +// It creates suggestions for: +// - Domain contains subdomain: contains relationship (parent → child) +// - Domain/subdomain with resolved_ip -> IP address asset: resolves_to relationship +func (s *RelationshipSuggestionService) GenerateSuggestions(ctx context.Context, tenantID string) (int, error) { + parsedTenantID, err := shared.IDFromString(tenantID) + if err != nil { + return 0, fmt.Errorf("%w: invalid tenant ID", shared.ErrValidation) + } + + s.logger.Info("generating relationship suggestions", "tenant_id", tenantID) + + // Clean up stale pending suggestions before regenerating. + if cleanErr := s.suggestionRepo.DeletePending(ctx, parsedTenantID); cleanErr != nil { + s.logger.Warn("failed to clean pending suggestions", "error", cleanErr) + } + + // Fetch all assets by type using pagination loop to handle large datasets. + domains, err := s.fetchAllAssets(ctx, tenantID, asset.AssetTypeDomain) + if err != nil { + return 0, fmt.Errorf("failed to list domains: %w", err) + } + + subdomains, err := s.fetchAllAssets(ctx, tenantID, asset.AssetTypeSubdomain) + if err != nil { + return 0, fmt.Errorf("failed to list subdomains: %w", err) + } + + ips, err := s.fetchAllAssets(ctx, tenantID, asset.AssetTypeIPAddress) + if err != nil { + return 0, fmt.Errorf("failed to list IP addresses: %w", err) + } + + // Build lookup maps + domainMap := make(map[string]*asset.Asset, len(domains)) + for _, d := range domains { + domainMap[d.Name()] = d + } + + ipMap := make(map[string]*asset.Asset, len(ips)) + for _, ip := range ips { + ipMap[ip.Name()] = ip + } + + suggestions := make([]*relationship.Suggestion, 0) + + // Generate domain contains subdomain suggestions (parent → child) + for _, sub := range subdomains { + parentDomain := findParentDomain(sub.Name(), domainMap) + if parentDomain != nil { + suggestion, suggErr := relationship.NewSuggestion( + parsedTenantID, + parentDomain.ID(), + sub.ID(), + string(asset.RelTypeContains), + fmt.Sprintf("Domain %s contains subdomain %s", parentDomain.Name(), sub.Name()), + 0.95, + ) + if suggErr != nil { + s.logger.Warn("failed to create member_of suggestion", "error", suggErr) + continue + } + suggestions = append(suggestions, suggestion) + } + } + + // Generate resolves_to suggestions for domains/subdomains with resolved_ip + allDNSAssets := make([]*asset.Asset, 0, len(domains)+len(subdomains)) + allDNSAssets = append(allDNSAssets, domains...) + allDNSAssets = append(allDNSAssets, subdomains...) + + for _, dnsAsset := range allDNSAssets { + resolvedIP := getResolvedIP(dnsAsset) + if resolvedIP == "" { + continue + } + + ipAsset, found := ipMap[resolvedIP] + if !found { + continue + } + + suggestion, suggErr := relationship.NewSuggestion( + parsedTenantID, + dnsAsset.ID(), + ipAsset.ID(), + string(asset.RelTypeResolvesTo), + fmt.Sprintf("%s resolves to IP %s", dnsAsset.Name(), resolvedIP), + 0.90, + ) + if suggErr != nil { + s.logger.Warn("failed to create resolves_to suggestion", "error", suggErr) + continue + } + suggestions = append(suggestions, suggestion) + } + + if len(suggestions) == 0 { + s.logger.Info("no suggestions generated", "tenant_id", tenantID) + return 0, nil + } + + created, err := s.suggestionRepo.CreateBatch(ctx, suggestions) + if err != nil { + return 0, fmt.Errorf("failed to create suggestions: %w", err) + } + + s.logger.Info("suggestions generated", "tenant_id", tenantID, "total", len(suggestions), "created", created) + return created, nil +} + +// ListPending returns pending suggestions for a tenant, optionally filtered by search. +func (s *RelationshipSuggestionService) ListPending(ctx context.Context, tenantID string, search string, page pagination.Pagination) (pagination.Result[*relationship.Suggestion], error) { + parsedTenantID, err := shared.IDFromString(tenantID) + if err != nil { + return pagination.Result[*relationship.Suggestion]{}, fmt.Errorf("%w: invalid tenant ID", shared.ErrValidation) + } + + return s.suggestionRepo.ListPending(ctx, parsedTenantID, search, page) +} + +// ApproveBatch approves multiple suggestions by IDs. +// Returns (approved count, error). Returns error only if ALL items failed. +func (s *RelationshipSuggestionService) ApproveBatch(ctx context.Context, tenantID string, ids []string, reviewerID string) (int, error) { + const maxBatchSize = 1000 + if len(ids) > maxBatchSize { + return 0, fmt.Errorf("%w: batch size exceeds maximum of %d", shared.ErrValidation, maxBatchSize) + } + + approved := 0 + for _, id := range ids { + if err := s.Approve(ctx, tenantID, id, reviewerID); err != nil { + s.logger.Warn("failed to approve suggestion in batch", "id", id, "error", err) + continue + } + approved++ + } + + if approved == 0 && len(ids) > 0 { + return 0, fmt.Errorf("failed to approve any of the %d suggestions", len(ids)) + } + + return approved, nil +} + +// Approve approves a suggestion and creates the real relationship. +func (s *RelationshipSuggestionService) Approve(ctx context.Context, tenantID, suggestionID, reviewerID string) error { + parsedTenantID, err := shared.IDFromString(tenantID) + if err != nil { + return fmt.Errorf("%w: invalid tenant ID", shared.ErrValidation) + } + parsedSuggestionID, err := shared.IDFromString(suggestionID) + if err != nil { + return fmt.Errorf("%w: invalid suggestion ID", shared.ErrValidation) + } + parsedReviewerID, err := shared.IDFromString(reviewerID) + if err != nil { + return fmt.Errorf("%w: invalid reviewer ID", shared.ErrValidation) + } + + // Fetch the suggestion + suggestion, err := s.suggestionRepo.GetByID(ctx, parsedTenantID, parsedSuggestionID) + if err != nil { + return err + } + + if suggestion.Status() != relationship.SuggestionPending { + return fmt.Errorf("%w: suggestion is not pending", shared.ErrValidation) + } + + // Create the real relationship + relType, parseErr := asset.ParseRelationshipType(suggestion.RelationshipType()) + if parseErr != nil { + return fmt.Errorf("invalid relationship type in suggestion: %w", parseErr) + } + + rel, relErr := asset.NewRelationship( + parsedTenantID, + suggestion.SourceAssetID(), + suggestion.TargetAssetID(), + relType, + ) + if relErr != nil { + return fmt.Errorf("failed to create relationship from suggestion: %w", relErr) + } + rel.SetDescription(suggestion.Reason()) + + if createErr := s.relRepo.Create(ctx, rel); createErr != nil { + // If relationship already exists, still mark suggestion as approved + if !isAlreadyExists(createErr) { + return fmt.Errorf("failed to persist relationship: %w", createErr) + } + s.logger.Info("relationship already exists, marking suggestion as approved", "suggestion_id", suggestionID) + } + + // Mark as approved + suggestion.Approve(parsedReviewerID) + if updateErr := s.suggestionRepo.UpdateStatus(ctx, suggestion); updateErr != nil { + return fmt.Errorf("failed to update suggestion status: %w", updateErr) + } + + return nil +} + +// ApproveAll approves all pending suggestions and creates relationships for each. +func (s *RelationshipSuggestionService) ApproveAll(ctx context.Context, tenantID, reviewerID string) (int, error) { + parsedTenantID, err := shared.IDFromString(tenantID) + if err != nil { + return 0, fmt.Errorf("%w: invalid tenant ID", shared.ErrValidation) + } + parsedReviewerID, err := shared.IDFromString(reviewerID) + if err != nil { + return 0, fmt.Errorf("%w: invalid reviewer ID", shared.ErrValidation) + } + + // Approve all in DB and get the approved suggestions + approved, err := s.suggestionRepo.ApproveAll(ctx, parsedTenantID, parsedReviewerID) + if err != nil { + return 0, fmt.Errorf("failed to approve all suggestions: %w", err) + } + + // Create relationships for each approved suggestion + created := 0 + for _, suggestion := range approved { + relType, parseErr := asset.ParseRelationshipType(suggestion.RelationshipType()) + if parseErr != nil { + s.logger.Warn("skipping suggestion with invalid relationship type", "suggestion_id", suggestion.ID().String(), "type", suggestion.RelationshipType()) + continue + } + + rel, relErr := asset.NewRelationship( + parsedTenantID, + suggestion.SourceAssetID(), + suggestion.TargetAssetID(), + relType, + ) + if relErr != nil { + s.logger.Warn("failed to create relationship from suggestion", "suggestion_id", suggestion.ID().String(), "error", relErr) + continue + } + rel.SetDescription(suggestion.Reason()) + + if createErr := s.relRepo.Create(ctx, rel); createErr != nil { + if !isAlreadyExists(createErr) { + s.logger.Warn("failed to persist relationship", "suggestion_id", suggestion.ID().String(), "error", createErr) + } + continue + } + created++ + } + + s.logger.Info("bulk approved suggestions", "tenant_id", tenantID, "approved", len(approved), "relationships_created", created) + return len(approved), nil +} + +// UpdateRelationshipType changes the relationship type of a pending suggestion. +func (s *RelationshipSuggestionService) UpdateRelationshipType(ctx context.Context, tenantID, suggestionID, relType string) error { + parsedTenantID, err := shared.IDFromString(tenantID) + if err != nil { + return fmt.Errorf("%w: invalid tenant ID", shared.ErrValidation) + } + parsedSuggestionID, err := shared.IDFromString(suggestionID) + if err != nil { + return fmt.Errorf("%w: invalid suggestion ID", shared.ErrValidation) + } + if relType == "" { + return fmt.Errorf("%w: relationship type is required", shared.ErrValidation) + } + // Validate the relationship type is valid + if _, parseErr := asset.ParseRelationshipType(relType); parseErr != nil { + return fmt.Errorf("%w: invalid relationship type: %s", shared.ErrValidation, relType) + } + + return s.suggestionRepo.UpdateRelationshipType(ctx, parsedTenantID, parsedSuggestionID, relType) +} + +// Dismiss marks a suggestion as dismissed. +func (s *RelationshipSuggestionService) Dismiss(ctx context.Context, tenantID, suggestionID, reviewerID string) error { + parsedTenantID, err := shared.IDFromString(tenantID) + if err != nil { + return fmt.Errorf("%w: invalid tenant ID", shared.ErrValidation) + } + parsedSuggestionID, err := shared.IDFromString(suggestionID) + if err != nil { + return fmt.Errorf("%w: invalid suggestion ID", shared.ErrValidation) + } + parsedReviewerID, err := shared.IDFromString(reviewerID) + if err != nil { + return fmt.Errorf("%w: invalid reviewer ID", shared.ErrValidation) + } + + suggestion, err := s.suggestionRepo.GetByID(ctx, parsedTenantID, parsedSuggestionID) + if err != nil { + return err + } + + if suggestion.Status() != relationship.SuggestionPending { + return fmt.Errorf("%w: suggestion is not pending", shared.ErrValidation) + } + + suggestion.Dismiss(parsedReviewerID) + return s.suggestionRepo.UpdateStatus(ctx, suggestion) +} + +// CountPending returns the number of pending suggestions for a tenant. +func (s *RelationshipSuggestionService) CountPending(ctx context.Context, tenantID string) (int64, error) { + parsedTenantID, err := shared.IDFromString(tenantID) + if err != nil { + return 0, fmt.Errorf("%w: invalid tenant ID", shared.ErrValidation) + } + + return s.suggestionRepo.CountPending(ctx, parsedTenantID) +} + +// ============================================================================= +// Helpers +// ============================================================================= + +// findParentDomain extracts the parent domain from a subdomain name and looks it up. +// Example: "api.example.com" -> look for "example.com" in the map. +func findParentDomain(subdomainName string, domainMap map[string]*asset.Asset) *asset.Asset { + parts := strings.SplitN(subdomainName, ".", 2) + if len(parts) < 2 { + return nil + } + parentName := parts[1] + + // Direct lookup + if parent, ok := domainMap[parentName]; ok { + return parent + } + + // Try further up the hierarchy (e.g., "a.b.example.com" -> "b.example.com" -> "example.com") + for { + parts = strings.SplitN(parentName, ".", 2) + if len(parts) < 2 { + break + } + parentName = parts[1] + if parent, ok := domainMap[parentName]; ok { + return parent + } + } + + return nil +} + +// getResolvedIP extracts the resolved_ip property from an asset. +func getResolvedIP(a *asset.Asset) string { + props := a.Properties() + + // Check resolved_ip property + if ip, ok := props["resolved_ip"]; ok { + if ipStr, ok := ip.(string); ok && ipStr != "" { + return ipStr + } + } + + // Check resolved_ips (array — take the first) + if ips, ok := props["resolved_ips"]; ok { + switch v := ips.(type) { + case []any: + if len(v) > 0 { + if ipStr, ok := v[0].(string); ok { + return ipStr + } + } + case []string: + if len(v) > 0 { + return v[0] + } + } + } + + return "" +} + +// fetchAllAssets retrieves all assets of a given type for a tenant, paginating through all pages. +// This prevents the LIMIT 100 cap from silently truncating large datasets. +func (s *RelationshipSuggestionService) fetchAllAssets(ctx context.Context, tenantID string, assetType asset.AssetType) ([]*asset.Asset, error) { + const pageSize = 100 + filter := asset.Filter{ + TenantID: &tenantID, + Types: []asset.AssetType{assetType}, + } + + var all []*asset.Asset + for page := 1; ; page++ { + result, err := s.assetRepo.List(ctx, filter, asset.ListOptions{}, pagination.New(page, pageSize)) + if err != nil { + return nil, err + } + all = append(all, result.Data...) + if len(all) >= int(result.Total) || len(result.Data) < pageSize { + break + } + } + return all, nil +} + +// isAlreadyExists checks if an error is an "already exists" error. +func isAlreadyExists(err error) bool { + return err != nil && strings.Contains(err.Error(), "already exists") +} diff --git a/internal/app/remediation_campaign_service.go b/internal/app/remediation_campaign_service.go index 6c554cbd..fe935573 100644 --- a/internal/app/remediation_campaign_service.go +++ b/internal/app/remediation_campaign_service.go @@ -3,6 +3,7 @@ package app import ( "context" "fmt" + "time" "github.com/openctemio/api/pkg/domain/remediation" "github.com/openctemio/api/pkg/domain/shared" @@ -90,6 +91,49 @@ func (s *RemediationCampaignService) ListCampaigns(ctx context.Context, tenantID return s.repo.List(ctx, filter, page) } +// UpdateRemediationCampaignInput holds fields for partial campaign update. +type UpdateRemediationCampaignInput struct { + Name *string + Description *string + Priority *string + Tags []string + DueDate *time.Time +} + +// UpdateCampaign updates campaign fields (name, description, priority, tags, due_date). +func (s *RemediationCampaignService) UpdateCampaign(ctx context.Context, tenantID, campaignID string, input UpdateRemediationCampaignInput) (*remediation.Campaign, error) { + tid, _ := shared.IDFromString(tenantID) + cid, _ := shared.IDFromString(campaignID) + + campaign, err := s.repo.GetByID(ctx, tid, cid) + if err != nil { + return nil, err + } + + if input.Name != nil { + campaign.SetName(*input.Name) + } + if input.Description != nil { + campaign.SetDescription(*input.Description) + } + if input.Priority != nil { + campaign.SetPriority(remediation.CampaignPriority(*input.Priority)) + } + if input.Tags != nil { + campaign.SetTags(input.Tags) + } + if input.DueDate != nil { + campaign.SetDueDate(input.DueDate) + } + + if err := s.repo.Update(ctx, campaign); err != nil { + return nil, fmt.Errorf("failed to update campaign: %w", err) + } + + s.logger.Info("remediation campaign updated", "id", campaignID) + return campaign, nil +} + // UpdateCampaignStatus transitions campaign status. func (s *RemediationCampaignService) UpdateCampaignStatus(ctx context.Context, tenantID, campaignID, newStatus string) (*remediation.Campaign, error) { tid, _ := shared.IDFromString(tenantID) @@ -109,6 +153,12 @@ func (s *RemediationCampaignService) UpdateCampaignStatus(ctx context.Context, t err = campaign.StartValidation() case remediation.CampaignStatusCompleted: err = campaign.Complete() + // Record risk reduction: resolved / total as a simple risk metric + if err == nil && campaign.FindingCount() > 0 { + before := float64(campaign.FindingCount()) + after := float64(campaign.FindingCount() - campaign.ResolvedCount()) + campaign.RecordRiskReduction(before, after) + } case remediation.CampaignStatusCanceled: campaign.Cancel() default: diff --git a/internal/app/vulnerability_service.go b/internal/app/vulnerability_service.go index fecdf7c6..6e321c10 100644 --- a/internal/app/vulnerability_service.go +++ b/internal/app/vulnerability_service.go @@ -973,6 +973,12 @@ func (s *VulnerabilityService) UpdateFindingStatus(ctx context.Context, findingI } } + // TODO(workflow-dispatch): Wire WorkflowEventDispatcher.DispatchFindingEvent() + // here to auto-trigger workflows (e.g., verification scan) on status change. + // Currently, verification scans are triggered manually via POST /findings/{id}/request-verification-scan. + // To automate: inject WorkflowEventDispatcher into VulnerabilityService, call + // DispatchFindingEvent(ctx, FindingEvent{TenantID, Finding, EventType: "finding_status_changed"}) + // Notify assignee about status change (fire and forget) if oldStatus != status.String() { s.notifyAssignee(ctx, f, input.ActorID, diff --git a/internal/infra/controller/control_test_scheduler.go b/internal/infra/controller/control_test_scheduler.go new file mode 100644 index 00000000..e98d64c5 --- /dev/null +++ b/internal/infra/controller/control_test_scheduler.go @@ -0,0 +1,111 @@ +package controller + +import ( + "context" + "time" + + "github.com/openctemio/api/pkg/domain/simulation" + "github.com/openctemio/api/pkg/logger" +) + +// ControlTestSchedulerController automatically marks control tests as overdue +// when they have not been run within the configured stale window. +// +// Design: +// - Runs every 24 hours (daily sweep). +// - Any control test not tested for >StaleDays days is reset to "untested" +// so it surfaces in the Detection Coverage dashboard. +// - Never blocks other operations — failures are logged and skipped. +type ControlTestSchedulerController struct { + repo simulation.ControlTestRepository + config *ControlTestSchedulerConfig + logger *logger.Logger +} + +// ControlTestSchedulerConfig configures the controller. +type ControlTestSchedulerConfig struct { + // Interval is how often the scheduler runs (default: 24 hours). + Interval time.Duration + + // StaleDays is the number of days without a test before a control test + // is considered overdue and reset to "untested" (default: 30). + StaleDays int + + // BatchSize is the maximum number of overdue tests to process per cycle (default: 500). + BatchSize int + + // Logger is passed by the controller manager. + Logger *logger.Logger +} + +// NewControlTestSchedulerController creates a new controller. +func NewControlTestSchedulerController( + repo simulation.ControlTestRepository, + cfg *ControlTestSchedulerConfig, +) *ControlTestSchedulerController { + if cfg.Interval == 0 { + cfg.Interval = 24 * time.Hour + } + if cfg.StaleDays == 0 { + cfg.StaleDays = 30 + } + if cfg.BatchSize == 0 { + cfg.BatchSize = 500 + } + return &ControlTestSchedulerController{ + repo: repo, + config: cfg, + logger: cfg.Logger, + } +} + +// Name implements Controller. +func (c *ControlTestSchedulerController) Name() string { return "control-test-scheduler" } + +// Interval implements Controller. +func (c *ControlTestSchedulerController) Interval() time.Duration { return c.config.Interval } + +// Reconcile finds all overdue control tests and resets their status to "untested". +// Returns the count of tests marked overdue. +func (c *ControlTestSchedulerController) Reconcile(ctx context.Context) (int, error) { + overdueTests, err := c.repo.ListOverdue(ctx, c.config.StaleDays, c.config.BatchSize) + if err != nil { + return 0, err + } + + if len(overdueTests) == 0 { + return 0, nil + } + + marked := 0 + for _, ct := range overdueTests { + if err := c.repo.MarkOverdue(ctx, ct.TenantID, ct.ControlTestID); err != nil { + c.logger.Warn("failed to mark control test overdue", + "tenant_id", ct.TenantID.String(), + "control_test_id", ct.ControlTestID.String(), + "name", ct.Name, + "error", err, + ) + continue + } + + c.logger.Info("control test marked overdue", + "tenant_id", ct.TenantID.String(), + "control_test_id", ct.ControlTestID.String(), + "framework", ct.Framework, + "name", ct.Name, + "days_since_tested", ct.DaysSinceTested, + ) + marked++ + } + + if marked > 0 { + c.logger.Info("control test scheduler cycle completed", + "overdue_found", len(overdueTests), + "marked_untested", marked, + "stale_days", c.config.StaleDays, + ) + } + + return marked, nil +} diff --git a/internal/infra/controller/controller.go b/internal/infra/controller/controller.go index 531f5706..651f7912 100644 --- a/internal/infra/controller/controller.go +++ b/internal/infra/controller/controller.go @@ -215,12 +215,8 @@ func (m *Manager) reconcileOnce(ctx context.Context, c Controller) { "items_processed", count, "duration", duration, ) - } else { - m.logger.Debug("controller reconcile completed (no items)", - "name", name, - "duration", duration, - ) } + // No log when zero items — reduces noise in dev/prod logs if m.metrics != nil { m.metrics.RecordReconcile(name, count, duration, nil) } diff --git a/internal/infra/controller/threat_intel_refresh.go b/internal/infra/controller/threat_intel_refresh.go new file mode 100644 index 00000000..ca658221 --- /dev/null +++ b/internal/infra/controller/threat_intel_refresh.go @@ -0,0 +1,45 @@ +package controller + +import ( + "context" + "time" + + "github.com/openctemio/api/internal/app" + "github.com/openctemio/api/pkg/logger" +) + +// ThreatIntelRefreshController periodically refreshes EPSS scores and KEV catalog. +// Runs every 24 hours. Fetches latest data from FIRST.org (EPSS) and CISA (KEV), +// then persists to database via ThreatIntelService.SyncAll(). +type ThreatIntelRefreshController struct { + service *app.ThreatIntelService + logger *logger.Logger +} + +// NewThreatIntelRefreshController creates a new controller. +func NewThreatIntelRefreshController(service *app.ThreatIntelService, log *logger.Logger) *ThreatIntelRefreshController { + return &ThreatIntelRefreshController{service: service, logger: log} +} + +// Name returns the controller name. +func (c *ThreatIntelRefreshController) Name() string { return "threat-intel-refresh" } + +// Interval returns 24 hours — daily refresh. +func (c *ThreatIntelRefreshController) Interval() time.Duration { return 24 * time.Hour } + +// Reconcile fetches and persists latest EPSS + KEV data. +func (c *ThreatIntelRefreshController) Reconcile(ctx context.Context) (int, error) { + results := c.service.SyncAll(ctx) + + processed := 0 + for _, r := range results { + if r.Error != nil { + c.logger.Warn("threat intel sync failed", "source", r.Source, "error", r.Error) + } else { + processed += r.RecordsSynced + c.logger.Info("threat intel synced", "source", r.Source, "records", r.RecordsSynced, "duration_ms", r.DurationMs) + } + } + + return processed, nil +} diff --git a/internal/infra/http/handler/ai_triage_handler.go b/internal/infra/http/handler/ai_triage_handler.go index 158cc525..f57522a3 100644 --- a/internal/infra/http/handler/ai_triage_handler.go +++ b/internal/infra/http/handler/ai_triage_handler.go @@ -212,7 +212,7 @@ func (h *AITriageHandler) GetTriageResultByID(w http.ResponseWriter, r *http.Req tenantID := middleware.MustGetTenantID(r.Context()) findingID := r.PathValue("id") - triageID := r.PathValue("triage_id") + triageID := r.PathValue("triageId") if findingID == "" || triageID == "" { apierror.BadRequest("Finding ID and Triage ID are required").WriteJSON(w) return diff --git a/internal/infra/http/handler/asset_handler.go b/internal/infra/http/handler/asset_handler.go index abc4186b..2ae70979 100644 --- a/internal/infra/http/handler/asset_handler.go +++ b/internal/infra/http/handler/asset_handler.go @@ -1160,7 +1160,8 @@ type AssetStatsResponse struct { WithFindings int `json:"with_findings"` RiskScoreAvg float64 `json:"risk_score_avg"` FindingsTotal int `json:"findings_total"` - HighRiskCount int `json:"high_risk_count"` // Assets with risk_score >= 70 + HighRiskCount int `json:"high_risk_count"` + MetadataCounts map[string]map[string]int `json:"metadata_counts,omitempty"` } // GetStats handles GET /api/v1/assets/stats @@ -1183,25 +1184,29 @@ func (h *AssetHandler) GetStats(w http.ResponseWriter, r *http.Request) { tagsFilter := parseQueryArray(query.Get("tags")) subTypeFilter := query.Get("sub_type") + // Parse count_by fields for metadata counting (e.g., ?count_by=is_virtual,os,ssl) + countByFields := parseQueryArray(query.Get("count_by")) + // Use service method with SQL aggregation for efficient stats - aggStats, err := h.service.GetAssetStats(r.Context(), tenantID, typesFilter, tagsFilter, subTypeFilter) + aggStats, err := h.service.GetAssetStats(r.Context(), tenantID, typesFilter, tagsFilter, subTypeFilter, countByFields...) if err != nil { h.handleServiceError(w, err) return } stats := AssetStatsResponse{ - Total: aggStats.Total, - ByType: aggStats.ByType, - BySubType: aggStats.BySubType, - ByStatus: aggStats.ByStatus, - ByCriticality: aggStats.ByCriticality, - ByScope: aggStats.ByScope, - ByExposure: aggStats.ByExposure, - WithFindings: aggStats.WithFindings, - FindingsTotal: aggStats.FindingsTotal, - HighRiskCount: aggStats.HighRiskCount, - RiskScoreAvg: aggStats.RiskScoreAvg, + Total: aggStats.Total, + ByType: aggStats.ByType, + BySubType: aggStats.BySubType, + ByStatus: aggStats.ByStatus, + ByCriticality: aggStats.ByCriticality, + ByScope: aggStats.ByScope, + ByExposure: aggStats.ByExposure, + WithFindings: aggStats.WithFindings, + FindingsTotal: aggStats.FindingsTotal, + HighRiskCount: aggStats.HighRiskCount, + RiskScoreAvg: aggStats.RiskScoreAvg, + MetadataCounts: aggStats.MetadataCounts, } w.Header().Set("Content-Type", "application/json") @@ -1215,6 +1220,7 @@ func (h *AssetHandler) ListTags(w http.ResponseWriter, r *http.Request) { tenantID := middleware.MustGetTenantID(r.Context()) prefix := r.URL.Query().Get("prefix") + types := r.URL.Query()["type"] limitStr := r.URL.Query().Get("limit") limit := 50 if limitStr != "" { @@ -1223,7 +1229,7 @@ func (h *AssetHandler) ListTags(w http.ResponseWriter, r *http.Request) { } } - tags, err := h.service.ListTags(r.Context(), tenantID, prefix, limit) + tags, err := h.service.ListTags(r.Context(), tenantID, prefix, types, limit) if err != nil { h.handleServiceError(w, err) return diff --git a/internal/infra/http/handler/asset_owner_handler.go b/internal/infra/http/handler/asset_owner_handler.go index 0ca617c5..fca76b27 100644 --- a/internal/infra/http/handler/asset_owner_handler.go +++ b/internal/infra/http/handler/asset_owner_handler.go @@ -280,7 +280,7 @@ func (h *AssetOwnerHandler) AddOwner(w http.ResponseWriter, r *http.Request) { // UpdateOwner handles PUT /api/v1/assets/{id}/owners/{ownerID} func (h *AssetOwnerHandler) UpdateOwner(w http.ResponseWriter, r *http.Request) { assetID := r.PathValue("id") - ownerID := r.PathValue("ownerID") + ownerID := r.PathValue("ownerId") if ownerID == "" { apierror.BadRequest("Owner ID is required").WriteJSON(w) return @@ -344,7 +344,7 @@ func (h *AssetOwnerHandler) UpdateOwner(w http.ResponseWriter, r *http.Request) // RemoveOwner handles DELETE /api/v1/assets/{id}/owners/{ownerID} func (h *AssetOwnerHandler) RemoveOwner(w http.ResponseWriter, r *http.Request) { assetID := r.PathValue("id") - ownerID := r.PathValue("ownerID") + ownerID := r.PathValue("ownerId") if ownerID == "" { apierror.BadRequest("Owner ID is required").WriteJSON(w) return diff --git a/internal/infra/http/handler/asset_service_handler.go b/internal/infra/http/handler/asset_service_handler.go index bc4ee6dd..51025a9c 100644 --- a/internal/infra/http/handler/asset_service_handler.go +++ b/internal/infra/http/handler/asset_service_handler.go @@ -598,8 +598,9 @@ func (h *AssetServiceHandler) List(w http.ResponseWriter, r *http.Request) { } // ListPublic handles GET /api/v1/services/public -// @Summary List public services -// @Description Retrieves a paginated list of publicly exposed services +// @Summary List public services (deprecated) +// @Description Deprecated: use GET /services?exposure=public instead. +// @Description Retrieves a paginated list of publicly exposed services. // @Tags Asset Services // @Accept json // @Produce json @@ -611,6 +612,8 @@ func (h *AssetServiceHandler) List(w http.ResponseWriter, r *http.Request) { // @Failure 500 {object} apierror.Error // @Router /services/public [get] func (h *AssetServiceHandler) ListPublic(w http.ResponseWriter, r *http.Request) { + // Deprecated: delegates to List with exposure=public pre-set. + // Use GET /services?exposure=public instead. ctx := r.Context() tenantIDStr := middleware.MustGetTenantID(ctx) tenantID, err := shared.IDFromString(tenantIDStr) @@ -619,25 +622,35 @@ func (h *AssetServiceHandler) ListPublic(w http.ResponseWriter, r *http.Request) return } - limit := 50 - offset := 0 + opts := asset.DefaultListAssetServicesOptions() + + // Pre-set is_public=true; callers can still override via ?is_public= query param. + // This is the backward-compat equivalent of the old /services/public endpoint. + if v := r.URL.Query().Get("is_public"); v != "" { + isPublic := v == queryParamTrue + opts.IsPublic = &isPublic + } else { + isPublicDefault := true + opts.IsPublic = &isPublicDefault + } + if v := r.URL.Query().Get("limit"); v != "" { if l, err := strconv.Atoi(v); err == nil && l > 0 { - limit = l + opts.Limit = l } } if v := r.URL.Query().Get("offset"); v != "" { if o, err := strconv.Atoi(v); err == nil && o >= 0 { - offset = o + opts.Offset = o } } // Security: Enforce max limit to prevent DoS via large queries const maxLimit = 1000 - if limit > maxLimit { - limit = maxLimit + if opts.Limit > maxLimit { + opts.Limit = maxLimit } - services, total, err := h.repo.ListPublic(ctx, tenantID, limit, offset) + services, total, err := h.repo.List(ctx, tenantID, opts) if err != nil { h.logger.Error("failed to list public services", "error", err) apierror.InternalError(err).WriteJSON(w) @@ -653,8 +666,8 @@ func (h *AssetServiceHandler) ListPublic(w http.ResponseWriter, r *http.Request) json.NewEncoder(w).Encode(map[string]interface{}{ "data": response, "total": total, - "limit": limit, - "offset": offset, + "limit": opts.Limit, + "offset": opts.Offset, }) } diff --git a/internal/infra/http/handler/asset_state_history_handler.go b/internal/infra/http/handler/asset_state_history_handler.go index 19ceacad..93c1c371 100644 --- a/internal/infra/http/handler/asset_state_history_handler.go +++ b/internal/infra/http/handler/asset_state_history_handler.go @@ -5,6 +5,7 @@ import ( "errors" "net/http" "strconv" + "strings" "time" "github.com/openctemio/api/internal/infra/http/middleware" @@ -148,12 +149,15 @@ func (h *AssetStateHistoryHandler) ListByAsset(w http.ResponseWriter, r *http.Re // List handles GET /api/v1/state-history // @Summary List all state history -// @Description Retrieves a paginated list of all state changes for the tenant with optional filtering +// @Description Retrieves a paginated list of all state changes for the tenant with optional filtering. +// @Description Use ?event_type= with comma-separated values to filter by one or more change types +// @Description (e.g. ?event_type=appeared,disappeared replaces the old /appearances and /disappearances endpoints). // @Tags Asset State History // @Accept json // @Produce json // @Security BearerAuth -// @Param change_type query string false "Filter by change type" +// @Param event_type query string false "Comma-separated change types (appeared,disappeared,shadow_it,exposure_changed,...)" +// @Param change_type query string false "Filter by single change type (deprecated: use event_type)" // @Param source query string false "Filter by source" // @Param from query string false "Start time (RFC3339)" // @Param to query string false "End time (RFC3339)" @@ -241,107 +245,62 @@ func (h *AssetStateHistoryHandler) Get(w http.ResponseWriter, r *http.Request) { } // RecentAppearances handles GET /api/v1/state-history/appearances -// @Summary Get recent asset appearances -// @Description Retrieves recently discovered assets (new assets appearing in scans) +// @Summary Get recent asset appearances (deprecated) +// @Description Deprecated: use GET /state-history?event_type=appeared instead. +// @Description Retrieves recently discovered assets (new assets appearing in scans). // @Tags Asset State History // @Accept json // @Produce json // @Security BearerAuth // @Param since query string false "Start time (RFC3339, default: 7 days ago)" // @Param limit query int false "Maximum results (max 1000)" default(100) -// @Success 200 {object} object{data=[]StateChangeResponse,total=int,since=string} +// @Success 200 {object} object{data=[]StateChangeResponse,total=int,limit=int,offset=int} // @Failure 401 {object} apierror.Error // @Failure 500 {object} apierror.Error // @Router /state-history/appearances [get] func (h *AssetStateHistoryHandler) RecentAppearances(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - tenantIDStr := middleware.MustGetTenantID(ctx) - tenantID, err := shared.IDFromString(tenantIDStr) - if err != nil { - apierror.Unauthorized("Invalid tenant ID").WriteJSON(w) - return - } - - since, limit := h.parseSinceAndLimit(r) - - changes, err := h.repo.GetRecentAppearances(ctx, tenantID, since, limit) - if err != nil { - h.logger.Error("failed to get recent appearances", "error", err) - apierror.InternalError(err).WriteJSON(w) - return - } - - response := make([]StateChangeResponse, len(changes)) - for i, change := range changes { - response[i] = toStateChangeResponse(change) - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]interface{}{ - "data": response, - "total": len(response), - "since": since.Format(time.RFC3339), - }) + // Deprecated: delegates to List with event_type=appeared pre-set. + // Use GET /state-history?event_type=appeared instead. + h.listWithPresetEventTypes(w, r, asset.StateChangeAppeared) } // RecentDisappearances handles GET /api/v1/state-history/disappearances -// @Summary Get recent asset disappearances -// @Description Retrieves assets that have disappeared (no longer seen in scans) +// @Summary Get recent asset disappearances (deprecated) +// @Description Deprecated: use GET /state-history?event_type=disappeared instead. +// @Description Retrieves assets that have disappeared (no longer seen in scans). // @Tags Asset State History // @Accept json // @Produce json // @Security BearerAuth // @Param since query string false "Start time (RFC3339, default: 7 days ago)" // @Param limit query int false "Maximum results (max 1000)" default(100) -// @Success 200 {object} object{data=[]StateChangeResponse,total=int,since=string} +// @Success 200 {object} object{data=[]StateChangeResponse,total=int,limit=int,offset=int} // @Failure 401 {object} apierror.Error // @Failure 500 {object} apierror.Error // @Router /state-history/disappearances [get] func (h *AssetStateHistoryHandler) RecentDisappearances(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - tenantIDStr := middleware.MustGetTenantID(ctx) - tenantID, err := shared.IDFromString(tenantIDStr) - if err != nil { - apierror.Unauthorized("Invalid tenant ID").WriteJSON(w) - return - } - - since, limit := h.parseSinceAndLimit(r) - - changes, err := h.repo.GetRecentDisappearances(ctx, tenantID, since, limit) - if err != nil { - h.logger.Error("failed to get recent disappearances", "error", err) - apierror.InternalError(err).WriteJSON(w) - return - } - - response := make([]StateChangeResponse, len(changes)) - for i, change := range changes { - response[i] = toStateChangeResponse(change) - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]interface{}{ - "data": response, - "total": len(response), - "since": since.Format(time.RFC3339), - }) + // Deprecated: delegates to List with event_type=disappeared pre-set. + // Use GET /state-history?event_type=disappeared instead. + h.listWithPresetEventTypes(w, r, asset.StateChangeDisappeared) } // ShadowITCandidates handles GET /api/v1/state-history/shadow-it -// @Summary Get Shadow IT candidates -// @Description Retrieves assets identified as potential Shadow IT (unexpected or unauthorized resources) +// @Summary Get Shadow IT candidates (deprecated) +// @Description Deprecated: use GET /state-history?event_type=appeared with scope filtering instead. +// @Description Retrieves assets identified as potential Shadow IT (appeared with shadow scope). // @Tags Asset State History // @Accept json // @Produce json // @Security BearerAuth // @Param since query string false "Start time (RFC3339, default: 7 days ago)" // @Param limit query int false "Maximum results (max 1000)" default(100) -// @Success 200 {object} object{data=[]StateChangeResponse,total=int,since=string} +// @Success 200 {object} object{data=[]StateChangeResponse,total=int,limit=int,offset=int} // @Failure 401 {object} apierror.Error // @Failure 500 {object} apierror.Error // @Router /state-history/shadow-it [get] func (h *AssetStateHistoryHandler) ShadowITCandidates(w http.ResponseWriter, r *http.Request) { + // Deprecated: shadow-it uses a specialized JOIN query (scope='shadow'). + // Keep delegating to the dedicated repo method for correctness. ctx := r.Context() tenantIDStr := middleware.MustGetTenantID(ctx) tenantID, err := shared.IDFromString(tenantIDStr) @@ -366,114 +325,81 @@ func (h *AssetStateHistoryHandler) ShadowITCandidates(w http.ResponseWriter, r * w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(map[string]interface{}{ - "data": response, - "total": len(response), - "since": since.Format(time.RFC3339), + "data": response, + "total": len(response), + "limit": limit, + "offset": 0, }) } // ExposureChanges handles GET /api/v1/state-history/exposure-changes -// @Summary Get exposure changes -// @Description Retrieves assets that have changed exposure status (public/private/restricted) +// @Summary Get exposure changes (deprecated) +// @Description Deprecated: use GET /state-history?event_type=exposure_changed,internet_exposure_changed instead. +// @Description Retrieves assets that have changed exposure status (public/private/restricted). // @Tags Asset State History // @Accept json // @Produce json // @Security BearerAuth // @Param since query string false "Start time (RFC3339, default: 7 days ago)" // @Param limit query int false "Maximum results (max 1000)" default(100) -// @Success 200 {object} object{data=[]StateChangeResponse,total=int,since=string} +// @Success 200 {object} object{data=[]StateChangeResponse,total=int,limit=int,offset=int} // @Failure 401 {object} apierror.Error // @Failure 500 {object} apierror.Error // @Router /state-history/exposure-changes [get] func (h *AssetStateHistoryHandler) ExposureChanges(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - tenantIDStr := middleware.MustGetTenantID(ctx) - tenantID, err := shared.IDFromString(tenantIDStr) - if err != nil { - apierror.Unauthorized("Invalid tenant ID").WriteJSON(w) - return - } - - since, limit := h.parseSinceAndLimit(r) - - changes, err := h.repo.GetExposureChanges(ctx, tenantID, since, limit) - if err != nil { - h.logger.Error("failed to get exposure changes", "error", err) - apierror.InternalError(err).WriteJSON(w) - return - } - - response := make([]StateChangeResponse, len(changes)) - for i, change := range changes { - response[i] = toStateChangeResponse(change) - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]interface{}{ - "data": response, - "total": len(response), - "since": since.Format(time.RFC3339), - }) + // Deprecated: delegates to List with event_type=exposure_changed,internet_exposure_changed pre-set. + // Use GET /state-history?event_type=exposure_changed,internet_exposure_changed instead. + h.listWithPresetEventTypes(w, r, asset.StateChangeExposureChanged, asset.StateChangeInternetExposureChanged) } // NewlyExposed handles GET /api/v1/state-history/newly-exposed -// @Summary Get newly exposed assets -// @Description Retrieves assets that have recently become publicly exposed +// @Summary Get newly exposed assets (deprecated) +// @Description Deprecated: use GET /state-history?event_type=internet_exposure_changed instead. +// @Description Retrieves assets that have recently become publicly exposed. // @Tags Asset State History // @Accept json // @Produce json // @Security BearerAuth // @Param since query string false "Start time (RFC3339, default: 7 days ago)" // @Param limit query int false "Maximum results (max 1000)" default(100) -// @Success 200 {object} object{data=[]StateChangeResponse,total=int,since=string} +// @Success 200 {object} object{data=[]StateChangeResponse,total=int,limit=int,offset=int} // @Failure 401 {object} apierror.Error // @Failure 500 {object} apierror.Error // @Router /state-history/newly-exposed [get] func (h *AssetStateHistoryHandler) NewlyExposed(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - tenantIDStr := middleware.MustGetTenantID(ctx) - tenantID, err := shared.IDFromString(tenantIDStr) - if err != nil { - apierror.Unauthorized("Invalid tenant ID").WriteJSON(w) - return - } - - since, limit := h.parseSinceAndLimit(r) - - changes, err := h.repo.GetNewlyExposedAssets(ctx, tenantID, since, limit) - if err != nil { - h.logger.Error("failed to get newly exposed assets", "error", err) - apierror.InternalError(err).WriteJSON(w) - return - } - - response := make([]StateChangeResponse, len(changes)) - for i, change := range changes { - response[i] = toStateChangeResponse(change) - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]interface{}{ - "data": response, - "total": len(response), - "since": since.Format(time.RFC3339), - }) + // Deprecated: delegates to List with event_type=internet_exposure_changed pre-set. + // Use GET /state-history?event_type=internet_exposure_changed instead. + h.listWithPresetEventTypes(w, r, asset.StateChangeInternetExposureChanged) } // ComplianceChanges handles GET /api/v1/state-history/compliance -// @Summary Get compliance-related changes -// @Description Retrieves state changes that may affect compliance status +// @Summary Get compliance-related changes (deprecated) +// @Description Deprecated: use GET /state-history?event_type=compliance_changed,classification_changed,owner_changed instead. +// @Description Retrieves state changes that may affect compliance status. // @Tags Asset State History // @Accept json // @Produce json // @Security BearerAuth // @Param since query string false "Start time (RFC3339, default: 7 days ago)" // @Param limit query int false "Maximum results (max 1000)" default(100) -// @Success 200 {object} object{data=[]StateChangeResponse,total=int,since=string} +// @Success 200 {object} object{data=[]StateChangeResponse,total=int,limit=int,offset=int} // @Failure 401 {object} apierror.Error // @Failure 500 {object} apierror.Error // @Router /state-history/compliance [get] func (h *AssetStateHistoryHandler) ComplianceChanges(w http.ResponseWriter, r *http.Request) { + // Deprecated: delegates to List with compliance event types pre-set. + // Use GET /state-history?event_type=compliance_changed,classification_changed,owner_changed instead. + h.listWithPresetEventTypes(w, r, + asset.StateChangeComplianceChanged, + asset.StateChangeClassificationChanged, + asset.StateChangeOwnerChanged, + ) +} + +// listWithPresetEventTypes is a shared helper that serves the deprecated specialized endpoints. +// It sets the given event types as defaults and then falls through to the normal List handler, +// allowing callers to still override via query params if needed. +func (h *AssetStateHistoryHandler) listWithPresetEventTypes(w http.ResponseWriter, r *http.Request, defaultTypes ...asset.StateChangeType) { ctx := r.Context() tenantIDStr := middleware.MustGetTenantID(ctx) tenantID, err := shared.IDFromString(tenantIDStr) @@ -482,11 +408,30 @@ func (h *AssetStateHistoryHandler) ComplianceChanges(w http.ResponseWriter, r *h return } - since, limit := h.parseSinceAndLimit(r) + // Parse options normally — the caller may still pass ?event_type= to override. + opts := h.parseListOptions(r) - changes, err := h.repo.GetComplianceChanges(ctx, tenantID, since, limit) + // Apply pre-set defaults only when no explicit event_type / change_type filter was provided. + if opts.ChangeType == nil && len(opts.ChangeTypes) == 0 { + if len(defaultTypes) == 1 { + opts.ChangeType = &defaultTypes[0] + } else { + opts.ChangeTypes = defaultTypes + } + } + + // Apply ?since= as a From bound when no explicit ?from= was given (backward compat). + if opts.From == nil { + since, limit := h.parseSinceAndLimit(r) + opts.From = &since + if opts.Limit == asset.DefaultListStateHistoryOptions().Limit { + opts.Limit = limit + } + } + + changes, total, err := h.repo.List(ctx, tenantID, opts) if err != nil { - h.logger.Error("failed to get compliance changes", "error", err) + h.logger.Error("failed to list state history", "error", err) apierror.InternalError(err).WriteJSON(w) return } @@ -498,9 +443,10 @@ func (h *AssetStateHistoryHandler) ComplianceChanges(w http.ResponseWriter, r *h w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(map[string]interface{}{ - "data": response, - "total": len(response), - "since": since.Format(time.RFC3339), + "data": response, + "total": total, + "limit": opts.Limit, + "offset": opts.Offset, }) } @@ -636,7 +582,23 @@ func (h *AssetStateHistoryHandler) Stats(w http.ResponseWriter, r *http.Request) func (h *AssetStateHistoryHandler) parseListOptions(r *http.Request) asset.ListStateHistoryOptions { opts := asset.DefaultListStateHistoryOptions() - if v := r.URL.Query().Get("change_type"); v != "" { + // event_type supports comma-separated list of change types, e.g. ?event_type=appeared,disappeared + // This is the preferred param going forward; change_type is kept for backward compatibility. + if v := r.URL.Query().Get("event_type"); v != "" { + parts := strings.Split(v, ",") + types := make([]asset.StateChangeType, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + if p != "" { + types = append(types, asset.StateChangeType(p)) + } + } + if len(types) == 1 { + opts.ChangeType = &types[0] + } else if len(types) > 1 { + opts.ChangeTypes = types + } + } else if v := r.URL.Query().Get("change_type"); v != "" { ct := asset.StateChangeType(v) opts.ChangeType = &ct } diff --git a/internal/infra/http/handler/attack_surface_handler.go b/internal/infra/http/handler/attack_surface_handler.go index ab4d27d9..14a99e75 100644 --- a/internal/infra/http/handler/attack_surface_handler.go +++ b/internal/infra/http/handler/attack_surface_handler.go @@ -116,6 +116,104 @@ func (h *AttackSurfaceHandler) GetStats(w http.ResponseWriter, r *http.Request) _ = json.NewEncoder(w).Encode(response) } +// AttackPathScoreResponse represents a single asset with its attack path score. +type AttackPathScoreResponse struct { + AssetID string `json:"asset_id"` + Name string `json:"name"` + AssetType string `json:"asset_type"` + Exposure string `json:"exposure"` + Criticality string `json:"criticality"` + RiskScore int `json:"risk_score"` + IsCrownJewel bool `json:"is_crown_jewel"` + FindingCount int `json:"finding_count"` + ReachableFrom int `json:"reachable_from"` + PathScore float64 `json:"path_score"` + IsEntryPoint bool `json:"is_entry_point"` + IsProtected bool `json:"is_protected"` +} + +// AttackPathSummaryResponse holds aggregate attack path metrics. +type AttackPathSummaryResponse struct { + TotalPaths int `json:"total_paths"` + EntryPoints int `json:"entry_points"` + ReachableAssets int `json:"reachable_assets"` + MaxDepth int `json:"max_depth"` + CriticalReachable int `json:"critical_reachable"` + CrownJewelsAtRisk int `json:"crown_jewels_at_risk"` + HasRelationshipData bool `json:"has_relationship_data"` +} + +// AttackPathScoringResponse is the response for the attack path scoring endpoint. +type AttackPathScoringResponse struct { + Summary AttackPathSummaryResponse `json:"summary"` + TopAssets []AttackPathScoreResponse `json:"top_assets"` +} + +// GetAttackPaths computes attack path scoring for the current tenant. +// @Summary Get attack path scoring +// @Description Computes reachability-based attack path scores for all assets +// @Tags Attack Surface +// @Produce json +// @Security BearerAuth +// @Success 200 {object} AttackPathScoringResponse +// @Failure 400 {object} apierror.Error +// @Failure 401 {object} apierror.Error +// @Failure 500 {object} apierror.Error +// @Router /attack-surface/attack-paths [get] +func (h *AttackSurfaceHandler) GetAttackPaths(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + tenantIDStr := middleware.MustGetTenantID(ctx) + tenantID, err := shared.IDFromString(tenantIDStr) + if err != nil { + apierror.BadRequest("Invalid tenant ID format").WriteJSON(w) + return + } + + result, err := h.service.GetAttackPathScores(ctx, tenantID) + if err != nil { + h.logger.Error("failed to compute attack path scores", "error", err) + apierror.InternalError(err).WriteJSON(w) + return + } + + // Map to response + topAssets := make([]AttackPathScoreResponse, len(result.TopAssets)) + for i, a := range result.TopAssets { + topAssets[i] = AttackPathScoreResponse{ + AssetID: a.AssetID, + Name: a.Name, + AssetType: a.AssetType, + Exposure: a.Exposure, + Criticality: a.Criticality, + RiskScore: a.RiskScore, + IsCrownJewel: a.IsCrownJewel, + FindingCount: a.FindingCount, + ReachableFrom: a.ReachableFrom, + PathScore: a.PathScore, + IsEntryPoint: a.IsEntryPoint, + IsProtected: a.IsProtected, + } + } + + response := AttackPathScoringResponse{ + Summary: AttackPathSummaryResponse{ + TotalPaths: result.Summary.TotalPaths, + EntryPoints: result.Summary.EntryPoints, + ReachableAssets: result.Summary.ReachableAssets, + MaxDepth: result.Summary.MaxDepth, + CriticalReachable: result.Summary.CriticalReachable, + CrownJewelsAtRisk: result.Summary.CrownJewelsAtRisk, + HasRelationshipData: result.Summary.HasRelationshipData, + }, + TopAssets: topAssets, + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(response) +} + // toStatsResponse converts service stats to API response. func (h *AttackSurfaceHandler) toStatsResponse(stats *app.AttackSurfaceStats) AttackSurfaceStatsResponse { // Convert asset breakdown diff --git a/internal/infra/http/handler/business_unit_handler.go b/internal/infra/http/handler/business_unit_handler.go index a322a3c5..0d5a5f48 100644 --- a/internal/infra/http/handler/business_unit_handler.go +++ b/internal/infra/http/handler/business_unit_handler.go @@ -68,6 +68,26 @@ func (h *BusinessUnitHandler) Get(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusOK, toBUResp(bu)) } +// Update updates a business unit. +func (h *BusinessUnitHandler) Update(w http.ResponseWriter, r *http.Request) { + tenantID := middleware.MustGetTenantID(r.Context()) + buID := chi.URLParam(r, "id") + var req CreateBURequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + apierror.BadRequest("invalid request body").WriteJSON(w) + return + } + bu, err := h.service.Update(r.Context(), app.UpdateBusinessUnitInput{ + TenantID: tenantID, ID: buID, Name: req.Name, Description: req.Description, + OwnerName: req.OwnerName, OwnerEmail: req.OwnerEmail, Tags: req.Tags, + }) + if err != nil { + h.handleError(w, err) + return + } + writeJSON(w, http.StatusOK, toBUResp(bu)) +} + // Delete deletes a business unit. func (h *BusinessUnitHandler) Delete(w http.ResponseWriter, r *http.Request) { tenantID := middleware.MustGetTenantID(r.Context()) diff --git a/internal/infra/http/handler/finding_actions_handler.go b/internal/infra/http/handler/finding_actions_handler.go index a8e70ad2..bb957d5a 100644 --- a/internal/infra/http/handler/finding_actions_handler.go +++ b/internal/infra/http/handler/finding_actions_handler.go @@ -286,6 +286,50 @@ func (h *FindingActionsHandler) RejectFix(w http.ResponseWriter, r *http.Request h.writeJSON(w, http.StatusOK, result) } +// --- Request Verification Scan --- + +// RequestVerificationScanRequest is the request body for POST /api/v1/findings/{id}/request-verification. +type RequestVerificationScanRequest struct { + ScannerName string `json:"scanner_name" validate:"omitempty,max=100"` + WorkflowID string `json:"workflow_id" validate:"omitempty,uuid"` +} + +// RequestVerificationScan handles POST /api/v1/findings/{id}/request-verification +// Triggers a targeted quick scan on the asset associated with a fix_applied finding. +func (h *FindingActionsHandler) RequestVerificationScan(w http.ResponseWriter, r *http.Request) { + tenantID := middleware.MustGetTenantID(r.Context()) + userID := middleware.GetLocalUserID(r.Context()) + findingID := chi.URLParam(r, "id") + + if findingID == "" { + apierror.BadRequest("finding id is required").WriteJSON(w) + return + } + + var req RequestVerificationScanRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + apierror.BadRequest("Invalid request body").WriteJSON(w) + return + } + + if len(req.ScannerName) > 100 { + apierror.BadRequest("scanner_name must be at most 100 characters").WriteJSON(w) + return + } + + result, err := h.service.RequestVerificationScan(r.Context(), tenantID, userID.String(), app.RequestVerificationScanInput{ + FindingID: findingID, + ScannerName: req.ScannerName, + WorkflowID: req.WorkflowID, + }) + if err != nil { + h.handleError(w, err) + return + } + + h.writeJSON(w, http.StatusAccepted, result) +} + // --- Auto-Assign --- // AssignToOwnersRequest is the request body for POST /api/v1/findings/actions/assign-to-owners diff --git a/internal/infra/http/handler/finding_activity_handler.go b/internal/infra/http/handler/finding_activity_handler.go index 70a047f5..ea7db588 100644 --- a/internal/infra/http/handler/finding_activity_handler.go +++ b/internal/infra/http/handler/finding_activity_handler.go @@ -156,7 +156,7 @@ func (h *FindingActivityHandler) GetActivity(w http.ResponseWriter, r *http.Requ tenantID := middleware.MustGetTenantID(r.Context()) findingID := r.PathValue("id") - activityID := r.PathValue("activity_id") + activityID := r.PathValue("activityId") if findingID == "" || activityID == "" { apierror.BadRequest("Finding ID and Activity ID are required").WriteJSON(w) return diff --git a/internal/infra/http/handler/jira_webhook_handler.go b/internal/infra/http/handler/jira_webhook_handler.go new file mode 100644 index 00000000..f56efd8b --- /dev/null +++ b/internal/infra/http/handler/jira_webhook_handler.go @@ -0,0 +1,165 @@ +package handler + +import ( + "encoding/json" + "errors" + "net/http" + + "github.com/go-chi/chi/v5" + + "github.com/openctemio/api/internal/app" + "github.com/openctemio/api/internal/infra/http/middleware" + "github.com/openctemio/api/pkg/apierror" + "github.com/openctemio/api/pkg/domain/shared" + "github.com/openctemio/api/pkg/logger" +) + +// JiraWebhookHandler handles Jira bidirectional ticket sync endpoints. +// +// Endpoints: +// - POST /api/v1/findings/{id}/link-ticket — link a Jira ticket to a finding +// - DELETE /api/v1/findings/{id}/link-ticket — unlink a Jira ticket from a finding +// - POST /api/v1/webhooks/incoming/jira — receive Jira status-change webhooks +type JiraWebhookHandler struct { + service *app.JiraSyncService + logger *logger.Logger +} + +// NewJiraWebhookHandler creates a new JiraWebhookHandler. +func NewJiraWebhookHandler(svc *app.JiraSyncService, log *logger.Logger) *JiraWebhookHandler { + return &JiraWebhookHandler{service: svc, logger: log} +} + +// LinkTicketRequest is the request body for POST /api/v1/findings/{id}/link-ticket. +type LinkTicketRequest struct { + TicketKey string `json:"ticket_key" validate:"required,min=1,max=255"` + TicketURL string `json:"ticket_url" validate:"required,url,max=1000"` +} + +// UnlinkTicketRequest is the request body for DELETE /api/v1/findings/{id}/link-ticket. +type UnlinkTicketRequest struct { + TicketURL string `json:"ticket_url" validate:"required,url,max=1000"` +} + +// LinkTicket handles POST /api/v1/findings/{id}/link-ticket. +// Links a Jira ticket to a finding by storing its URL in work_item_uris. +func (h *JiraWebhookHandler) LinkTicket(w http.ResponseWriter, r *http.Request) { + tenantID := middleware.MustGetTenantID(r.Context()) + findingID := chi.URLParam(r, "id") + if findingID == "" { + apierror.BadRequest("finding id is required").WriteJSON(w) + return + } + + var req LinkTicketRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + apierror.BadRequest("invalid request body").WriteJSON(w) + return + } + + if req.TicketKey == "" { + apierror.BadRequest("ticket_key is required").WriteJSON(w) + return + } + if req.TicketURL == "" { + apierror.BadRequest("ticket_url is required").WriteJSON(w) + return + } + + input := app.LinkTicketInput{ + TenantID: tenantID, + FindingID: findingID, + TicketKey: req.TicketKey, + TicketURL: req.TicketURL, + } + + if err := h.service.LinkTicket(r.Context(), input); err != nil { + h.handleError(w, err) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]any{ + "finding_id": findingID, + "ticket_key": req.TicketKey, + "ticket_url": req.TicketURL, + "message": "ticket linked successfully", + }) +} + +// UnlinkTicket handles DELETE /api/v1/findings/{id}/link-ticket. +// Removes a Jira ticket reference from a finding's work_item_uris. +func (h *JiraWebhookHandler) UnlinkTicket(w http.ResponseWriter, r *http.Request) { + tenantID := middleware.MustGetTenantID(r.Context()) + findingID := chi.URLParam(r, "id") + if findingID == "" { + apierror.BadRequest("finding id is required").WriteJSON(w) + return + } + + var req UnlinkTicketRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + apierror.BadRequest("invalid request body").WriteJSON(w) + return + } + if req.TicketURL == "" { + apierror.BadRequest("ticket_url is required").WriteJSON(w) + return + } + + if err := h.service.UnlinkTicket(r.Context(), tenantID, findingID, req.TicketURL); err != nil { + h.handleError(w, err) + return + } + + w.WriteHeader(http.StatusNoContent) +} + +// IncomingJiraWebhook handles POST /api/v1/webhooks/incoming/jira. +// This is a PUBLIC endpoint (no JWT) intended to receive Jira webhook deliveries. +// Tenant routing is via the ?tenant= query param — each Jira project configures one endpoint per tenant. +func (h *JiraWebhookHandler) IncomingJiraWebhook(w http.ResponseWriter, r *http.Request) { + tenantIDStr := r.URL.Query().Get("tenant") + if tenantIDStr == "" { + apierror.BadRequest("tenant query parameter is required").WriteJSON(w) + return + } + + tenantID, err := shared.IDFromString(tenantIDStr) + if err != nil { + apierror.BadRequest("invalid tenant id").WriteJSON(w) + return + } + + var payload app.JiraWebhookPayload + if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { + apierror.BadRequest("invalid jira webhook payload").WriteJSON(w) + return + } + + if err := h.service.HandleJiraWebhook(r.Context(), tenantID, payload); err != nil { + h.logger.Error("jira webhook processing failed", + "tenant_id", tenantIDStr, + "error", err, + ) + apierror.InternalServerError("webhook processing failed").WriteJSON(w) + return + } + + // Always return 200 — Jira expects a 2xx or it will retry. + w.WriteHeader(http.StatusOK) +} + +// handleError maps domain errors to HTTP responses. +func (h *JiraWebhookHandler) handleError(w http.ResponseWriter, err error) { + switch { + case errors.Is(err, shared.ErrNotFound): + apierror.NotFound("finding not found").WriteJSON(w) + case errors.Is(err, shared.ErrValidation): + apierror.BadRequest(err.Error()).WriteJSON(w) + default: + h.logger.Error("jira handler error", "error", err) + apierror.InternalServerError("internal server error").WriteJSON(w) + } +} diff --git a/internal/infra/http/handler/relationship_suggestion_handler.go b/internal/infra/http/handler/relationship_suggestion_handler.go new file mode 100644 index 00000000..df46c61c --- /dev/null +++ b/internal/infra/http/handler/relationship_suggestion_handler.go @@ -0,0 +1,296 @@ +package handler + +import ( + "encoding/json" + "errors" + "net/http" + "sync" + "time" + + "github.com/go-chi/chi/v5" + "github.com/openctemio/api/internal/app" + "github.com/openctemio/api/internal/infra/http/middleware" + "github.com/openctemio/api/pkg/apierror" + "github.com/openctemio/api/pkg/domain/relationship" + "github.com/openctemio/api/pkg/domain/shared" + "github.com/openctemio/api/pkg/logger" + "github.com/openctemio/api/pkg/pagination" +) + +// RelationshipSuggestionHandler handles relationship suggestion HTTP requests. +type RelationshipSuggestionHandler struct { + service *app.RelationshipSuggestionService + logger *logger.Logger + generateMu sync.Mutex + lastGenerateAt map[string]time.Time // tenant_id -> last generate time +} + +// NewRelationshipSuggestionHandler creates a new RelationshipSuggestionHandler. +func NewRelationshipSuggestionHandler(svc *app.RelationshipSuggestionService, log *logger.Logger) *RelationshipSuggestionHandler { + return &RelationshipSuggestionHandler{ + service: svc, + logger: log, + lastGenerateAt: make(map[string]time.Time), + } +} + +// ============================================================================= +// Response types +// ============================================================================= + +// SuggestionResponse represents a suggestion in API responses. +type SuggestionResponse struct { + ID string `json:"id"` + SourceAssetID string `json:"source_asset_id"` + SourceAssetName string `json:"source_asset_name"` + SourceAssetType string `json:"source_asset_type"` + TargetAssetID string `json:"target_asset_id"` + TargetAssetName string `json:"target_asset_name"` + TargetAssetType string `json:"target_asset_type"` + RelationshipType string `json:"relationship_type"` + Reason string `json:"reason"` + Confidence float64 `json:"confidence"` + Status string `json:"status"` + ReviewedBy *string `json:"reviewed_by,omitempty"` + ReviewedAt *time.Time `json:"reviewed_at,omitempty"` + CreatedAt time.Time `json:"created_at"` +} + +// ============================================================================= +// Handlers +// ============================================================================= + +// List handles GET /api/v1/relationships/suggestions +func (h *RelationshipSuggestionHandler) List(w http.ResponseWriter, r *http.Request) { + tenantID := middleware.MustGetTenantID(r.Context()) + + query := r.URL.Query() + page := pagination.New( + parseQueryInt(query.Get("page"), 1), + parseQueryInt(query.Get("per_page"), 20), + ) + + search := query.Get("search") + result, err := h.service.ListPending(r.Context(), tenantID, search, page) + if err != nil { + h.handleServiceError(w, err) + return + } + + data := make([]SuggestionResponse, 0, len(result.Data)) + for _, s := range result.Data { + data = append(data, toSuggestionResponse(s)) + } + + response := ListResponse[SuggestionResponse]{ + Data: data, + Total: result.Total, + Page: result.Page, + PerPage: result.PerPage, + TotalPages: result.TotalPages, + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(response) +} + +// Approve handles POST /api/v1/relationships/suggestions/{id}/approve +func (h *RelationshipSuggestionHandler) Approve(w http.ResponseWriter, r *http.Request) { + tenantID := middleware.MustGetTenantID(r.Context()) + suggestionID := chi.URLParam(r, "id") + reviewerID := middleware.GetUserID(r.Context()) + + if reviewerID == "" { + apierror.Unauthorized("user ID required").WriteJSON(w) + return + } + + if err := h.service.Approve(r.Context(), tenantID, suggestionID, reviewerID); err != nil { + h.handleServiceError(w, err) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]string{"status": "approved"}) +} + +// Dismiss handles POST /api/v1/relationships/suggestions/{id}/dismiss +func (h *RelationshipSuggestionHandler) Dismiss(w http.ResponseWriter, r *http.Request) { + tenantID := middleware.MustGetTenantID(r.Context()) + suggestionID := chi.URLParam(r, "id") + reviewerID := middleware.GetUserID(r.Context()) + + if reviewerID == "" { + apierror.Unauthorized("user ID required").WriteJSON(w) + return + } + + if err := h.service.Dismiss(r.Context(), tenantID, suggestionID, reviewerID); err != nil { + h.handleServiceError(w, err) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]string{"status": "dismissed"}) +} + +// ApproveAll handles POST /api/v1/relationships/suggestions/approve-all +func (h *RelationshipSuggestionHandler) ApproveAll(w http.ResponseWriter, r *http.Request) { + tenantID := middleware.MustGetTenantID(r.Context()) + reviewerID := middleware.GetUserID(r.Context()) + + if reviewerID == "" { + apierror.Unauthorized("user ID required").WriteJSON(w) + return + } + + count, err := h.service.ApproveAll(r.Context(), tenantID, reviewerID) + if err != nil { + h.handleServiceError(w, err) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]any{"status": "approved", "count": count}) +} + +// ApproveBatch handles POST /api/v1/relationships/suggestions/approve-batch +func (h *RelationshipSuggestionHandler) ApproveBatch(w http.ResponseWriter, r *http.Request) { + tenantID := middleware.MustGetTenantID(r.Context()) + reviewerID := middleware.GetUserID(r.Context()) + + if reviewerID == "" { + apierror.Unauthorized("user ID required").WriteJSON(w) + return + } + + var req struct { + IDs []string `json:"ids"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + apierror.BadRequest("invalid request body").WriteJSON(w) + return + } + if len(req.IDs) == 0 { + apierror.BadRequest("ids array is required").WriteJSON(w) + return + } + + count, err := h.service.ApproveBatch(r.Context(), tenantID, req.IDs, reviewerID) + if err != nil { + h.handleServiceError(w, err) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]any{"status": "approved", "count": count}) +} + +// UpdateType handles PATCH /api/v1/relationships/suggestions/{id}/type +func (h *RelationshipSuggestionHandler) UpdateType(w http.ResponseWriter, r *http.Request) { + tenantID := middleware.MustGetTenantID(r.Context()) + suggestionID := chi.URLParam(r, "id") + + var req struct { + RelationshipType string `json:"relationship_type"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + apierror.BadRequest("invalid request body").WriteJSON(w) + return + } + + if err := h.service.UpdateRelationshipType(r.Context(), tenantID, suggestionID, req.RelationshipType); err != nil { + h.handleServiceError(w, err) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]string{"status": "updated"}) +} + +// Generate handles POST /api/v1/relationships/suggestions/generate +func (h *RelationshipSuggestionHandler) Generate(w http.ResponseWriter, r *http.Request) { + tenantID := middleware.MustGetTenantID(r.Context()) + + // Per-tenant cooldown: 1 generate per 30 seconds + h.generateMu.Lock() + if lastAt, ok := h.lastGenerateAt[tenantID]; ok && time.Since(lastAt) < 30*time.Second { + h.generateMu.Unlock() + apierror.TooManyRequests("Please wait before scanning again").WriteJSON(w) + return + } + h.lastGenerateAt[tenantID] = time.Now() + h.generateMu.Unlock() + + count, err := h.service.GenerateSuggestions(r.Context(), tenantID) + if err != nil { + h.handleServiceError(w, err) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]any{"status": "generated", "count": count}) +} + +// CountPending handles GET /api/v1/relationships/suggestions/count +func (h *RelationshipSuggestionHandler) CountPending(w http.ResponseWriter, r *http.Request) { + tenantID := middleware.MustGetTenantID(r.Context()) + + count, err := h.service.CountPending(r.Context(), tenantID) + if err != nil { + h.handleServiceError(w, err) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]any{"count": count}) +} + +// ============================================================================= +// Helpers +// ============================================================================= + +func toSuggestionResponse(s *relationship.Suggestion) SuggestionResponse { + resp := SuggestionResponse{ + ID: s.ID().String(), + SourceAssetID: s.SourceAssetID().String(), + SourceAssetName: s.SourceAssetName(), + SourceAssetType: s.SourceAssetType(), + TargetAssetID: s.TargetAssetID().String(), + TargetAssetName: s.TargetAssetName(), + TargetAssetType: s.TargetAssetType(), + RelationshipType: s.RelationshipType(), + Reason: s.Reason(), + Confidence: s.Confidence(), + Status: s.Status(), + ReviewedAt: s.ReviewedAt(), + CreatedAt: s.CreatedAt(), + } + if s.ReviewedBy() != nil { + reviewedByStr := s.ReviewedBy().String() + resp.ReviewedBy = &reviewedByStr + } + return resp +} + +func (h *RelationshipSuggestionHandler) handleServiceError(w http.ResponseWriter, err error) { + switch { + case errors.Is(err, shared.ErrNotFound): + apierror.NotFound("Suggestion").WriteJSON(w) + case errors.Is(err, shared.ErrAlreadyExists): + apierror.Conflict("Suggestion already exists").WriteJSON(w) + case errors.Is(err, shared.ErrValidation): + apierror.BadRequest(err.Error()).WriteJSON(w) + default: + h.logger.Error("service error", "error", err) + apierror.InternalError(err).WriteJSON(w) + } +} diff --git a/internal/infra/http/handler/remediation_campaign_handler.go b/internal/infra/http/handler/remediation_campaign_handler.go index a61f30e7..af652778 100644 --- a/internal/infra/http/handler/remediation_campaign_handler.go +++ b/internal/infra/http/handler/remediation_campaign_handler.go @@ -128,6 +128,31 @@ func (h *RemediationCampaignHandler) UpdateStatus(w http.ResponseWriter, r *http writeJSON(w, http.StatusOK, toRemediationCampaignResp(campaign)) } +// Update updates campaign fields (name, description, priority, tags, due_date). +func (h *RemediationCampaignHandler) Update(w http.ResponseWriter, r *http.Request) { + tenantID := middleware.MustGetTenantID(r.Context()) + id := chi.URLParam(r, "id") + + var req UpdateRemCampaignRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + apierror.BadRequest("invalid request body").WriteJSON(w) + return + } + + campaign, err := h.service.UpdateCampaign(r.Context(), tenantID, id, app.UpdateRemediationCampaignInput{ + Name: req.Name, + Description: req.Description, + Priority: req.Priority, + Tags: req.Tags, + DueDate: req.DueDate, + }) + if err != nil { + h.handleError(w, err) + return + } + writeJSON(w, http.StatusOK, toRemediationCampaignResp(campaign)) +} + // Delete deletes a campaign. func (h *RemediationCampaignHandler) Delete(w http.ResponseWriter, r *http.Request) { tenantID := middleware.MustGetTenantID(r.Context()) @@ -163,6 +188,14 @@ type CreateRemCampaignRequest struct { Tags []string `json:"tags"` } +type UpdateRemCampaignRequest struct { + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Priority *string `json:"priority,omitempty"` + Tags []string `json:"tags,omitempty"` + DueDate *time.Time `json:"due_date,omitempty"` +} + type RemediationCampaignResponse struct { ID string `json:"id"` Name string `json:"name"` diff --git a/internal/infra/http/handler/tenant_handler.go b/internal/infra/http/handler/tenant_handler.go index 483d4c1f..76615125 100644 --- a/internal/infra/http/handler/tenant_handler.go +++ b/internal/infra/http/handler/tenant_handler.go @@ -699,7 +699,7 @@ func (h *TenantHandler) AddMember(w http.ResponseWriter, r *http.Request) { // UpdateMemberRole handles PATCH /api/v1/tenants/{tenant}/members/{memberId} func (h *TenantHandler) UpdateMemberRole(w http.ResponseWriter, r *http.Request) { - memberID := r.PathValue("memberId") + memberID := r.PathValue("userId") if memberID == "" { apierror.BadRequest("Member ID is required").WriteJSON(w) return @@ -734,7 +734,7 @@ func (h *TenantHandler) UpdateMemberRole(w http.ResponseWriter, r *http.Request) // RemoveMember handles DELETE /api/v1/tenants/{tenant}/members/{memberId} func (h *TenantHandler) RemoveMember(w http.ResponseWriter, r *http.Request) { - memberID := r.PathValue("memberId") + memberID := r.PathValue("userId") if memberID == "" { apierror.BadRequest("Member ID is required").WriteJSON(w) return @@ -751,7 +751,7 @@ func (h *TenantHandler) RemoveMember(w http.ResponseWriter, r *http.Request) { // SuspendMember handles POST /api/v1/tenants/{tenant}/members/{memberId}/suspend func (h *TenantHandler) SuspendMember(w http.ResponseWriter, r *http.Request) { - memberID := r.PathValue("memberId") + memberID := r.PathValue("userId") if memberID == "" { apierror.BadRequest("Member ID is required").WriteJSON(w) return @@ -770,7 +770,7 @@ func (h *TenantHandler) SuspendMember(w http.ResponseWriter, r *http.Request) { // ReactivateMember handles POST /api/v1/tenants/{tenant}/members/{memberId}/reactivate func (h *TenantHandler) ReactivateMember(w http.ResponseWriter, r *http.Request) { - memberID := r.PathValue("memberId") + memberID := r.PathValue("userId") if memberID == "" { apierror.BadRequest("Member ID is required").WriteJSON(w) return diff --git a/internal/infra/http/handler/threatintel_handler.go b/internal/infra/http/handler/threatintel_handler.go index cec67508..be0bc50e 100644 --- a/internal/infra/http/handler/threatintel_handler.go +++ b/internal/infra/http/handler/threatintel_handler.go @@ -163,7 +163,7 @@ func (h *ThreatIntelHandler) SetSyncEnabled(w http.ResponseWriter, r *http.Reque // GET /api/v1/threat-intel/enrich/{cve_id} func (h *ThreatIntelHandler) EnrichCVE(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - cveID := chi.URLParam(r, "cve_id") + cveID := chi.URLParam(r, "cveId") enrichment, err := h.service.EnrichCVE(ctx, cveID) if err != nil { @@ -215,7 +215,7 @@ func (h *ThreatIntelHandler) EnrichCVEs(w http.ResponseWriter, r *http.Request) // GET /api/v1/threat-intel/epss/{cve_id} func (h *ThreatIntelHandler) GetEPSSScore(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - cveID := chi.URLParam(r, "cve_id") + cveID := chi.URLParam(r, "cveId") score, err := h.service.GetEPSSScore(ctx, cveID) if err != nil { @@ -234,7 +234,7 @@ func (h *ThreatIntelHandler) GetEPSSScore(w http.ResponseWriter, r *http.Request // GET /api/v1/threat-intel/kev/{cve_id} func (h *ThreatIntelHandler) GetKEVEntry(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - cveID := chi.URLParam(r, "cve_id") + cveID := chi.URLParam(r, "cveId") entry, err := h.service.GetKEVEntry(ctx, cveID) if err != nil { diff --git a/internal/infra/http/handler/tool_handler.go b/internal/infra/http/handler/tool_handler.go index b347f877..d3084a71 100644 --- a/internal/infra/http/handler/tool_handler.go +++ b/internal/infra/http/handler/tool_handler.go @@ -938,7 +938,7 @@ func (h *ToolHandler) ListTenantConfigs(w http.ResponseWriter, r *http.Request) // @Security BearerAuth // @Router /tenant-tools/{tool_id} [get] func (h *ToolHandler) GetTenantConfig(w http.ResponseWriter, r *http.Request) { - toolID := chi.URLParam(r, "tool_id") + toolID := chi.URLParam(r, "toolId") tenantID := middleware.GetTenantID(r.Context()) config, err := h.service.GetTenantToolConfig(r.Context(), tenantID, toolID) @@ -966,7 +966,7 @@ func (h *ToolHandler) GetTenantConfig(w http.ResponseWriter, r *http.Request) { // @Security BearerAuth // @Router /tenant-tools/{tool_id} [put] func (h *ToolHandler) UpdateTenantConfig(w http.ResponseWriter, r *http.Request) { - toolID := chi.URLParam(r, "tool_id") + toolID := chi.URLParam(r, "toolId") tenantID := middleware.GetTenantID(r.Context()) userID := middleware.GetUserID(r.Context()) @@ -1008,7 +1008,7 @@ func (h *ToolHandler) UpdateTenantConfig(w http.ResponseWriter, r *http.Request) // @Security BearerAuth // @Router /tenant-tools/{tool_id} [delete] func (h *ToolHandler) DeleteTenantConfig(w http.ResponseWriter, r *http.Request) { - toolID := chi.URLParam(r, "tool_id") + toolID := chi.URLParam(r, "toolId") tenantID := middleware.GetTenantID(r.Context()) if err := h.service.DeleteTenantToolConfig(r.Context(), tenantID, toolID); err != nil { @@ -1033,7 +1033,7 @@ func (h *ToolHandler) DeleteTenantConfig(w http.ResponseWriter, r *http.Request) // @Security BearerAuth // @Router /tenant-tools/{tool_id}/effective-config [get] func (h *ToolHandler) GetEffectiveConfig(w http.ResponseWriter, r *http.Request) { - toolID := chi.URLParam(r, "tool_id") + toolID := chi.URLParam(r, "toolId") tenantID := middleware.GetTenantID(r.Context()) config, err := h.service.GetEffectiveToolConfig(r.Context(), tenantID, toolID) @@ -1207,7 +1207,7 @@ func (h *ToolHandler) ListAllTools(w http.ResponseWriter, r *http.Request) { // @Security BearerAuth // @Router /tenant-tools/{tool_id}/with-config [get] func (h *ToolHandler) GetToolWithConfig(w http.ResponseWriter, r *http.Request) { - toolID := chi.URLParam(r, "tool_id") + toolID := chi.URLParam(r, "toolId") tenantID := middleware.GetTenantID(r.Context()) twc, err := h.service.GetToolWithConfig(r.Context(), tenantID, toolID) @@ -1287,7 +1287,7 @@ func (h *ToolHandler) GetTenantStats(w http.ResponseWriter, r *http.Request) { // @Security BearerAuth // @Router /tool-stats/{tool_id} [get] func (h *ToolHandler) GetToolStats(w http.ResponseWriter, r *http.Request) { - toolID := chi.URLParam(r, "tool_id") + toolID := chi.URLParam(r, "toolId") tenantID := middleware.GetTenantID(r.Context()) days := parseQueryInt(r.URL.Query().Get("days"), 30) diff --git a/internal/infra/http/handler/vulnerability_handler.go b/internal/infra/http/handler/vulnerability_handler.go index 565587b2..092e95d6 100644 --- a/internal/infra/http/handler/vulnerability_handler.go +++ b/internal/infra/http/handler/vulnerability_handler.go @@ -1160,7 +1160,7 @@ func (h *VulnerabilityHandler) GetVulnerability(w http.ResponseWriter, r *http.R // @Failure 404 {object} map[string]string // @Router /vulnerabilities/cve/{cve_id} [get] func (h *VulnerabilityHandler) GetVulnerabilityByCVE(w http.ResponseWriter, r *http.Request) { - cveID := r.PathValue("cve_id") + cveID := r.PathValue("cveId") if cveID == "" { apierror.BadRequest("CVE ID is required").WriteJSON(w) return @@ -2205,7 +2205,7 @@ func (h *VulnerabilityHandler) UpdateComment(w http.ResponseWriter, r *http.Requ } findingID := r.PathValue("id") - commentID := r.PathValue("comment_id") + commentID := r.PathValue("commentId") if findingID == "" || commentID == "" { apierror.BadRequest("Finding ID and Comment ID are required").WriteJSON(w) return @@ -2244,7 +2244,7 @@ func (h *VulnerabilityHandler) DeleteComment(w http.ResponseWriter, r *http.Requ userID = localUserID.String() } - commentID := r.PathValue("comment_id") + commentID := r.PathValue("commentId") if commentID == "" { apierror.BadRequest("Comment ID is required").WriteJSON(w) return diff --git a/internal/infra/http/routes/admin.go b/internal/infra/http/routes/admin.go index be69b078..5fb3b804 100644 --- a/internal/infra/http/routes/admin.go +++ b/internal/infra/http/routes/admin.go @@ -35,7 +35,7 @@ func registerAdminRoutes( // Admin user management (requires super_admin role) if h.AdminUser != nil { - router.Group("/api/v1/admin/admins", func(r Router) { + router.Group("/api/v1/admin/users", func(r Router) { r.GET("/", h.AdminUser.List) r.POST("/", h.AdminUser.Create) r.GET("/{id}", h.AdminUser.Get) diff --git a/internal/infra/http/routes/assets.go b/internal/infra/http/routes/assets.go index 8a51d8e5..480bdc01 100644 --- a/internal/infra/http/routes/assets.go +++ b/internal/infra/http/routes/assets.go @@ -31,7 +31,7 @@ func registerAssetRoutes( r.GET("/tags", h.ListTags, middleware.Require(permission.AssetsRead)) // Bulk operations (must be before /{id} patterns to avoid route conflicts) - r.POST("/bulk-sync", h.BulkSync, middleware.Require(permission.AssetsWrite)) + r.POST("/bulk/sync", h.BulkSync, middleware.Require(permission.AssetsWrite)) r.POST("/bulk/status", h.BulkUpdateStatus, middleware.Require(permission.AssetsWrite)) // Read operations @@ -79,8 +79,8 @@ func registerAssetOwnerRoutes( router.Group("/api/v1/assets/{id}/owners", func(r Router) { r.GET("/", h.ListOwners, middleware.Require(permission.AssetsRead)) r.POST("/", h.AddOwner, middleware.Require(permission.AssetsWrite)) - r.PUT("/{ownerID}", h.UpdateOwner, middleware.Require(permission.AssetsWrite)) - r.DELETE("/{ownerID}", h.RemoveOwner, middleware.Require(permission.AssetsDelete)) + r.PUT("/{ownerId}", h.UpdateOwner, middleware.Require(permission.AssetsWrite)) + r.DELETE("/{ownerId}", h.RemoveOwner, middleware.Require(permission.AssetsDelete)) }, middlewares...) } @@ -277,13 +277,13 @@ func registerFindingSourceRoutes( tenantMiddlewares := buildTokenTenantMiddlewares(authMiddleware, userSyncMiddleware) // Finding Source Category routes (read-only) - router.Group("/api/v1/config/finding-sources/categories", func(r Router) { + router.Group("/api/v1/finding-sources/categories", func(r Router) { r.GET("/", h.ListCategories, middleware.Require(permission.FindingsRead)) r.GET("/{categoryId}", h.GetCategory, middleware.Require(permission.FindingsRead)) }, tenantMiddlewares...) // Finding Source routes (read-only) - router.Group("/api/v1/config/finding-sources", func(r Router) { + router.Group("/api/v1/finding-sources", func(r Router) { r.GET("/", h.ListFindingSources, middleware.Require(permission.FindingsRead)) r.GET("/code/{code}", h.GetFindingSourceByCode, middleware.Require(permission.FindingsRead)) r.GET("/{id}", h.GetFindingSource, middleware.Require(permission.FindingsRead)) @@ -305,6 +305,9 @@ func registerAttackSurfaceRoutes( // Attack Surface routes router.Group("/api/v1/attack-surface", func(r Router) { r.GET("/stats", h.GetStats, middleware.Require(permission.AssetsRead)) + // Attack path scoring — BFS reachability analysis from public entry points. + // Returns top assets ranked by composite path score (reachability × risk × criticality). + r.GET("/attack-paths", h.GetAttackPaths, middleware.Require(permission.AssetsRead)) }, tenantMiddlewares...) } @@ -406,6 +409,30 @@ func registerAssetRelationshipRoutes( }, tenantMiddlewares...) } +// registerRelationshipSuggestionRoutes registers relationship suggestion endpoints. +// Suggestions are auto-generated relationship recommendations based on asset analysis. +func registerRelationshipSuggestionRoutes( + router Router, + h *handler.RelationshipSuggestionHandler, + authMiddleware Middleware, + userSyncMiddleware Middleware, +) { + // Build tenant middleware chain from JWT token + tenantMiddlewares := buildTokenTenantMiddlewares(authMiddleware, userSyncMiddleware) + + // Suggestion routes under /api/v1/relationships/suggestions + router.Group("/api/v1/relationships/suggestions", func(r Router) { + r.GET("/", h.List, middleware.Require(permission.AssetsRead)) + r.GET("/count", h.CountPending, middleware.Require(permission.AssetsRead)) + r.POST("/generate", h.Generate, middleware.Require(permission.AssetsWrite)) + r.POST("/approve-all", h.ApproveAll, middleware.Require(permission.AssetsWrite)) + r.POST("/approve-batch", h.ApproveBatch, middleware.Require(permission.AssetsWrite)) + r.POST("/{id}/approve", h.Approve, middleware.Require(permission.AssetsWrite)) + r.POST("/{id}/dismiss", h.Dismiss, middleware.Require(permission.AssetsWrite)) + r.PATCH("/{id}/type", h.UpdateType, middleware.Require(permission.AssetsWrite)) + }, tenantMiddlewares...) +} + // registerAssetStateHistoryRoutes registers asset state history endpoints. // State history tracks changes for audit, compliance, and shadow IT detection. // Part of the CTEM Discovery phase. diff --git a/internal/infra/http/routes/business_unit.go b/internal/infra/http/routes/business_unit.go index 56dc3032..04c02e3f 100644 --- a/internal/infra/http/routes/business_unit.go +++ b/internal/infra/http/routes/business_unit.go @@ -19,6 +19,7 @@ func registerBusinessUnitRoutes( r.GET("/", h.List, middleware.Require(permission.AssetsRead)) r.POST("/", h.Create, middleware.Require(permission.AssetsWrite)) r.GET("/{id}", h.Get, middleware.Require(permission.AssetsRead)) + r.PUT("/{id}", h.Update, middleware.Require(permission.AssetsWrite)) r.DELETE("/{id}", h.Delete, middleware.Require(permission.AssetsWrite)) r.POST("/{id}/assets", h.AddAsset, middleware.Require(permission.AssetsWrite)) r.DELETE("/{id}/assets/{assetId}", h.RemoveAsset, middleware.Require(permission.AssetsWrite)) diff --git a/internal/infra/http/routes/exposure.go b/internal/infra/http/routes/exposure.go index c1e23174..c6b2dc46 100644 --- a/internal/infra/http/routes/exposure.go +++ b/internal/infra/http/routes/exposure.go @@ -72,16 +72,16 @@ func registerThreatIntelRoutes( r.PATCH("/sync/{source}", h.SetSyncEnabled, middleware.Require(permission.VulnerabilitiesWrite)) // CVE enrichment (combine EPSS + KEV data) - r.GET("/enrich/{cve_id}", h.EnrichCVE, middleware.Require(permission.VulnerabilitiesRead)) + r.GET("/enrich/{cveId}", h.EnrichCVE, middleware.Require(permission.VulnerabilitiesRead)) r.POST("/enrich", h.EnrichCVEs, middleware.Require(permission.VulnerabilitiesRead)) - // EPSS scores (must have stats before {cve_id} to avoid route conflicts) + // EPSS scores (must have stats before {cveId} to avoid route conflicts) r.GET("/epss/stats", h.GetEPSSStats, middleware.Require(permission.VulnerabilitiesRead)) - r.GET("/epss/{cve_id}", h.GetEPSSScore, middleware.Require(permission.VulnerabilitiesRead)) + r.GET("/epss/{cveId}", h.GetEPSSScore, middleware.Require(permission.VulnerabilitiesRead)) - // KEV catalog (must have stats before {cve_id} to avoid route conflicts) + // KEV catalog (must have stats before {cveId} to avoid route conflicts) r.GET("/kev/stats", h.GetKEVStats, middleware.Require(permission.VulnerabilitiesRead)) - r.GET("/kev/{cve_id}", h.GetKEVEntry, middleware.Require(permission.VulnerabilitiesRead)) + r.GET("/kev/{cveId}", h.GetKEVEntry, middleware.Require(permission.VulnerabilitiesRead)) }, baseMiddlewares...) } @@ -152,6 +152,7 @@ func registerVulnerabilityRoutes( router Router, h *handler.VulnerabilityHandler, findingActionsHandler *handler.FindingActionsHandler, + jiraHandler *handler.JiraWebhookHandler, authMiddleware Middleware, userSyncMiddleware Middleware, ) { @@ -163,7 +164,7 @@ func registerVulnerabilityRoutes( // Read operations r.GET("/", h.ListVulnerabilities, middleware.Require(permission.VulnerabilitiesRead)) r.GET("/{id}", h.GetVulnerability, middleware.Require(permission.VulnerabilitiesRead)) - r.GET("/cve/{cve_id}", h.GetVulnerabilityByCVE, middleware.Require(permission.VulnerabilitiesRead)) + r.GET("/cve/{cveId}", h.GetVulnerabilityByCVE, middleware.Require(permission.VulnerabilitiesRead)) // Write operations (admin only) r.POST("/", h.CreateVulnerability, middleware.Require(permission.VulnerabilitiesWrite)) @@ -219,12 +220,24 @@ func registerVulnerabilityRoutes( r.PATCH("/{id}/triage", h.TriageFinding, middleware.Require(permission.FindingsWrite)) r.POST("/{id}/verify", h.VerifyFinding, middleware.Require(permission.FindingsWrite)) + // Verification scan automation: trigger a targeted scan on the finding's asset + // (only available when finding actions handler is wired) + if findingActionsHandler != nil { + r.POST("/{id}/request-verification", findingActionsHandler.RequestVerificationScan, middleware.Require(permission.FindingsWrite)) + } + // Tags r.PUT("/{id}/tags", h.SetFindingTags, middleware.Require(permission.FindingsWrite)) // Data flows (attack paths / taint tracking) r.GET("/{id}/dataflows", h.GetFindingDataFlows, middleware.Require(permission.FindingsRead)) + // Jira ticket linking — store/remove Jira ticket references on findings + if jiraHandler != nil { + r.POST("/{id}/link-ticket", jiraHandler.LinkTicket, middleware.Require(permission.FindingsWrite)) + r.DELETE("/{id}/link-ticket", jiraHandler.UnlinkTicket, middleware.Require(permission.FindingsWrite)) + } + // Delete operations r.DELETE("/{id}", h.DeleteFinding, middleware.Require(permission.FindingsDelete)) }, tenantMiddlewares...) @@ -238,8 +251,8 @@ func registerVulnerabilityRoutes( router.Group("/api/v1/findings/{id}/comments", func(r Router) { r.GET("/", h.ListComments, middleware.Require(permission.FindingsRead)) r.POST("/", h.AddComment, middleware.Require(permission.FindingsWrite)) - r.PUT("/{comment_id}", h.UpdateComment, middleware.Require(permission.FindingsWrite)) - r.DELETE("/{comment_id}", h.DeleteComment, middleware.Require(permission.FindingsWrite)) + r.PUT("/{commentId}", h.UpdateComment, middleware.Require(permission.FindingsWrite)) + r.DELETE("/{commentId}", h.DeleteComment, middleware.Require(permission.FindingsWrite)) }, tenantMiddlewares...) // Finding approval routes - tenant from JWT token @@ -280,7 +293,7 @@ func registerFindingActivityRoutes( // Finding activity routes - tenant from JWT token router.Group("/api/v1/findings/{id}/activities", func(r Router) { r.GET("/", h.ListActivities, middleware.Require(permission.FindingsRead)) - r.GET("/{activity_id}", h.GetActivity, middleware.Require(permission.FindingsRead)) + r.GET("/{activityId}", h.GetActivity, middleware.Require(permission.FindingsRead)) // Note: Activities are created automatically via service hooks, not via direct API // Real-time updates are delivered via WebSocket channel: finding:{id} }, tenantMiddlewares...) @@ -295,8 +308,8 @@ func registerFindingActivityRoutes( // - POST /api/v1/findings/ai-triage/bulk - Bulk triage multiple findings (rate-limited) // - GET /api/v1/findings/{id}/ai-triage - Get latest triage result // - GET /api/v1/findings/{id}/ai-triage/history - Get triage history -// - GET /api/v1/findings/{id}/ai-triage/{triage_id} - Get specific triage result -// - GET /api/v1/ai-triage/config - Get AI configuration info +// - GET /api/v1/findings/{id}/ai-triage/{triageId} - Get specific triage result +// - GET /api/v1/findings/ai-triage/config - Get AI configuration info func registerAITriageRoutes( router Router, h *handler.AITriageHandler, @@ -316,14 +329,14 @@ func registerAITriageRoutes( // AI triage routes - tenant from JWT token router.Group("/api/v1/findings/{id}/ai-triage", func(r Router) { - // Get latest triage result (must be before /{triage_id} to avoid conflicts) + // Get latest triage result (must be before /{triageId} to avoid conflicts) r.GET("/", h.GetTriageResult, middleware.Require(permission.FindingsRead)) - // Get triage history (must be before /{triage_id} to avoid conflicts) + // Get triage history (must be before /{triageId} to avoid conflicts) r.GET("/history", h.ListTriageHistory, middleware.Require(permission.FindingsRead)) // Get specific triage result by ID - r.GET("/{triage_id}", h.GetTriageResultByID, middleware.Require(permission.FindingsRead)) + r.GET("/{triageId}", h.GetTriageResultByID, middleware.Require(permission.FindingsRead)) }, tenantMiddlewares...) // Trigger AI triage for a finding (rate-limited) @@ -336,5 +349,5 @@ func registerAITriageRoutes( append(postMiddlewares, middleware.Require(permission.FindingsWrite))...) // AI triage config endpoint - returns current AI mode, provider, model - router.GET("/api/v1/ai-triage/config", h.GetConfig, tenantMiddlewares...) + router.GET("/api/v1/findings/ai-triage/config", h.GetConfig, tenantMiddlewares...) } diff --git a/internal/infra/http/routes/misc.go b/internal/infra/http/routes/misc.go index 1b2f70fe..90705e1b 100644 --- a/internal/infra/http/routes/misc.go +++ b/internal/infra/http/routes/misc.go @@ -327,3 +327,18 @@ func registerWebhookRoutes( r.GET("/{id}/deliveries", h.ListDeliveries, middleware.Require(permission.WebhooksRead)) }, tenantMiddlewares...) } + +// registerIncomingWebhookRoutes registers public incoming webhook endpoints. +// These endpoints are NOT protected by JWT — they are called by external services (e.g. Jira). +// Tenant routing is done via a ?tenant= query parameter that each external service configures. +func registerIncomingWebhookRoutes( + router Router, + jiraHandler *handler.JiraWebhookHandler, +) { + if jiraHandler == nil { + return + } + // Public endpoint — no auth middleware. + // Jira requires a 200 response on delivery, so we always accept and process asynchronously. + router.POST("/api/v1/webhooks/incoming/jira", jiraHandler.IncomingJiraWebhook) +} diff --git a/internal/infra/http/routes/remediation.go b/internal/infra/http/routes/remediation.go index 4523bd56..224935a2 100644 --- a/internal/infra/http/routes/remediation.go +++ b/internal/infra/http/routes/remediation.go @@ -19,6 +19,7 @@ func registerRemediationCampaignRoutes( r.GET("/", h.List, middleware.Require(permission.FindingsRead)) r.POST("/", h.Create, middleware.Require(permission.FindingsWrite)) r.GET("/{id}", h.Get, middleware.Require(permission.FindingsRead)) + r.PATCH("/{id}", h.Update, middleware.Require(permission.FindingsWrite)) r.PATCH("/{id}/status", h.UpdateStatus, middleware.Require(permission.FindingsWrite)) r.DELETE("/{id}", h.Delete, middleware.Require(permission.FindingsWrite)) }, tenantMiddlewares...) diff --git a/internal/infra/http/routes/routes.go b/internal/infra/http/routes/routes.go index 0a884709..abc764b7 100644 --- a/internal/infra/http/routes/routes.go +++ b/internal/infra/http/routes/routes.go @@ -71,7 +71,8 @@ type Handlers struct { // CTEM Discovery handlers AssetService *handler.AssetServiceHandler // nil if not initialized (no database) AssetStateHistory *handler.AssetStateHistoryHandler // nil if not initialized (no database) - AssetRelationship *handler.AssetRelationshipHandler // nil if not initialized (no database) + AssetRelationship *handler.AssetRelationshipHandler // nil if not initialized (no database) + RelationshipSuggestion *handler.RelationshipSuggestionHandler // nil if not initialized (no database) // Access Control handlers Group *handler.GroupHandler // nil if not initialized (no database) @@ -85,6 +86,9 @@ type Handlers struct { // Finding Lifecycle (closed-loop: fix_applied → verified → resolved) FindingActions *handler.FindingActionsHandler // nil if not initialized (no database) + // Jira Bidirectional Sync (link tickets to findings + receive Jira webhooks) + JiraWebhook *handler.JiraWebhookHandler // nil if not initialized (no database) + // Pentest Campaign Management handlers Pentest *handler.PentestHandler // nil if not initialized (no database) PentestCampaignRoleQry middleware.CampaignRoleQuerier // Campaign role resolver for RBAC middleware @@ -270,11 +274,19 @@ func Register( registerAssetRelationshipRoutes(router, h.AssetRelationship, authMiddleware, userSync) } + // Relationship Suggestion routes (auto-generated relationship recommendations) + if h.RelationshipSuggestion != nil { + registerRelationshipSuggestionRoutes(router, h.RelationshipSuggestion, authMiddleware, userSync) + } + // Vulnerability routes (global) and Finding routes (tenant from JWT token) if h.Vulnerability != nil { - registerVulnerabilityRoutes(router, h.Vulnerability, h.FindingActions, authMiddleware, userSync) + registerVulnerabilityRoutes(router, h.Vulnerability, h.FindingActions, h.JiraWebhook, authMiddleware, userSync) } + // Incoming Jira webhook — public endpoint (no JWT), Jira POSTs status changes here. + registerIncomingWebhookRoutes(router, h.JiraWebhook) + // Initialize finding activity rate limiter to prevent enumeration and DoS var activityRateLimiter *middleware.FindingActivityRateLimiter if cfg.RateLimit.Enabled { diff --git a/internal/infra/http/routes/scanning.go b/internal/infra/http/routes/scanning.go index ccded636..ee05e212 100644 --- a/internal/infra/http/routes/scanning.go +++ b/internal/infra/http/routes/scanning.go @@ -58,7 +58,7 @@ func registerAgentRoutes( // Supported formats: CTIS (native), SARIF (industry standard), Recon (discovery data), Chunk (for large reports) // All ingest endpoints support compressed request bodies (Content-Encoding: gzip or zstd) // Ingest endpoints use a 50MB body limit (vs 10MB default) for large scan reports - r.POST("/ingest", ingestHandler.IngestCTIS, ingestBodyLimit, decompressMiddleware) + r.POST("/ingest", ingestHandler.IngestCTIS, ingestBodyLimit, decompressMiddleware) // Primary CTIS ingest endpoint r.POST("/ingest/check", ingestHandler.CheckFingerprints, ingestBodyLimit, decompressMiddleware) r.POST("/ingest/sarif", ingestHandler.IngestSARIF, ingestBodyLimit, decompressMiddleware) r.POST("/ingest/ctis", ingestHandler.IngestCTIS, ingestBodyLimit, decompressMiddleware) @@ -263,31 +263,31 @@ func registerToolRoutes( // Tenant Tool Config routes (tenant-scoped) router.Group("/api/v1/tenant-tools", func(r Router) { - // Bulk operations (must be before /{tool_id} to avoid route conflicts) - r.POST("/bulk-enable", h.BulkEnable, middleware.Require(permission.TenantToolsWrite)) - r.POST("/bulk-disable", h.BulkDisable, middleware.Require(permission.TenantToolsWrite)) + // Bulk operations (must be before /{toolId} to avoid route conflicts) + r.POST("/bulk/enable", h.BulkEnable, middleware.Require(permission.TenantToolsWrite)) + r.POST("/bulk/disable", h.BulkDisable, middleware.Require(permission.TenantToolsWrite)) - // List all tools with tenant-specific enabled status (must be before /{tool_id}) + // List all tools with tenant-specific enabled status (must be before /{toolId}) r.GET("/all-tools", h.ListAllTools, middleware.Require(permission.TenantToolsRead)) // Read operations r.GET("/", h.ListTenantConfigs, middleware.Require(permission.TenantToolsRead)) - r.GET("/{tool_id}", h.GetTenantConfig, middleware.Require(permission.TenantToolsRead)) - r.GET("/{tool_id}/effective-config", h.GetEffectiveConfig, middleware.Require(permission.TenantToolsRead)) - r.GET("/{tool_id}/with-config", h.GetToolWithConfig, middleware.Require(permission.TenantToolsRead)) + r.GET("/{toolId}", h.GetTenantConfig, middleware.Require(permission.TenantToolsRead)) + r.GET("/{toolId}/effective-config", h.GetEffectiveConfig, middleware.Require(permission.TenantToolsRead)) + r.GET("/{toolId}/with-config", h.GetToolWithConfig, middleware.Require(permission.TenantToolsRead)) // Write operations - r.PUT("/{tool_id}", h.UpdateTenantConfig, middleware.Require(permission.TenantToolsWrite)) + r.PUT("/{toolId}", h.UpdateTenantConfig, middleware.Require(permission.TenantToolsWrite)) // Delete operations - r.DELETE("/{tool_id}", h.DeleteTenantConfig, middleware.Require(permission.TenantToolsDelete)) - }, tenantMiddlewares...) + r.DELETE("/{toolId}", h.DeleteTenantConfig, middleware.Require(permission.TenantToolsDelete)) - // Tool Stats routes (tenant-scoped) - router.Group("/api/v1/tool-stats", func(r Router) { - r.GET("/", h.GetTenantStats, middleware.Require(permission.TenantToolsRead)) - r.GET("/{tool_id}", h.GetToolStats, middleware.Require(permission.TenantToolsRead)) + // Stats (consolidated from /tool-stats) + r.GET("/stats", h.GetTenantStats, middleware.Require(permission.TenantToolsRead)) + r.GET("/stats/{toolId}", h.GetToolStats, middleware.Require(permission.TenantToolsRead)) }, tenantMiddlewares...) + + // /tool-stats removed — use /tenant-tools/stats } // registerToolCategoryRoutes registers tool category endpoints. @@ -373,25 +373,20 @@ func registerScanRoutes( // Build tenant middleware chain from JWT token tenantMiddlewares := buildTokenTenantMiddlewares(authMiddleware, userSyncMiddleware) - // Quick scan endpoint - separate from /scans to avoid conflict - router.Group("/api/v1/quick-scan", func(r Router) { - // Apply rate limiting to quick scans (stricter) - if triggerRateLimiter != nil { - r.POST("/", h.QuickScan, middleware.Require(permission.ScansWrite), triggerRateLimiter.QuickScanMiddleware()) - } else { - r.POST("/", h.QuickScan, middleware.Require(permission.ScansWrite)) - } - }, tenantMiddlewares...) - - // Scan management overview stats - router.Group("/api/v1/scan-management", func(r Router) { - r.GET("/stats", h.GetOverviewStats, middleware.Require(permission.ScansRead)) - }, tenantMiddlewares...) + // /quick-scan and /scan-management removed — use /scans/quick and /scans/overview-stats // Scan routes - tenant from JWT token router.Group("/api/v1/scans", func(r Router) { // Stats endpoint (must be before /{id} to avoid matching) r.GET("/stats", h.GetStats, middleware.Require(permission.ScansRead)) + // Overview stats (consolidated from /scan-management/stats) + r.GET("/overview-stats", h.GetOverviewStats, middleware.Require(permission.ScansRead)) + // Quick scan (consolidated from /quick-scan) + if triggerRateLimiter != nil { + r.POST("/quick", h.QuickScan, middleware.Require(permission.ScansWrite), triggerRateLimiter.QuickScanMiddleware()) + } else { + r.POST("/quick", h.QuickScan, middleware.Require(permission.ScansWrite)) + } // Bulk operations (must be before /{id} to avoid matching) r.POST("/bulk/activate", h.BulkActivate, middleware.Require(permission.ScansWrite)) diff --git a/internal/infra/http/routes/tenant.go b/internal/infra/http/routes/tenant.go index 6b138089..30e1cc78 100644 --- a/internal/infra/http/routes/tenant.go +++ b/internal/infra/http/routes/tenant.go @@ -62,10 +62,10 @@ func registerTenantRoutes( // Admin operations (admin+) r.PATCH("/", h.Update, middleware.RequireTeamAdmin()) r.POST("/members", h.AddMember, middleware.RequireTeamAdmin()) - r.PATCH("/members/{memberId}", h.UpdateMemberRole, middleware.RequireTeamAdmin()) - r.POST("/members/{memberId}/suspend", h.SuspendMember, middleware.RequireTeamAdmin()) - r.POST("/members/{memberId}/reactivate", h.ReactivateMember, middleware.RequireTeamAdmin()) - r.DELETE("/members/{memberId}", h.RemoveMember, middleware.RequireTeamAdmin()) + r.PATCH("/members/{userId}", h.UpdateMemberRole, middleware.RequireTeamAdmin()) + r.POST("/members/{userId}/suspend", h.SuspendMember, middleware.RequireTeamAdmin()) + r.POST("/members/{userId}/reactivate", h.ReactivateMember, middleware.RequireTeamAdmin()) + r.DELETE("/members/{userId}", h.RemoveMember, middleware.RequireTeamAdmin()) r.POST("/invitations", h.CreateInvitation, middleware.RequireTeamAdmin()) r.POST("/invitations/{invitationId}/resend", h.ResendInvitation, middleware.RequireTeamAdmin()) r.DELETE("/invitations/{invitationId}", h.DeleteInvitation, middleware.RequireTeamAdmin()) diff --git a/internal/infra/postgres/asset_relationship_repository.go b/internal/infra/postgres/asset_relationship_repository.go index d851184c..1c2d48d9 100644 --- a/internal/infra/postgres/asset_relationship_repository.go +++ b/internal/infra/postgres/asset_relationship_repository.go @@ -499,4 +499,36 @@ func (r *AssetRelationshipRepository) applyDirectionFilter(conditions string, qu return " AND FALSE" } +// ListAllEdges fetches every relationship for the tenant as lightweight graph +// edges. Used exclusively by attack path scoring which needs the full directed +// graph in-memory. Columns are minimal to keep the query fast. +func (r *AssetRelationshipRepository) ListAllEdges(ctx context.Context, tenantID shared.ID) ([]asset.RelationshipEdge, error) { + const query = ` + SELECT source_asset_id, target_asset_id, relationship_type, impact_weight + FROM asset_relationships + WHERE tenant_id = $1 + ORDER BY created_at + ` + rows, err := r.db.QueryContext(ctx, query, tenantID.String()) + if err != nil { + return nil, fmt.Errorf("list all edges: %w", err) + } + defer func() { _ = rows.Close() }() + + var edges []asset.RelationshipEdge + for rows.Next() { + var e asset.RelationshipEdge + var relType string + if scanErr := rows.Scan(&e.SourceAssetID, &e.TargetAssetID, &relType, &e.ImpactWeight); scanErr != nil { + return nil, fmt.Errorf("scan edge: %w", scanErr) + } + e.Type = asset.RelationshipType(relType) + edges = append(edges, e) + } + if err = rows.Err(); err != nil { + return nil, fmt.Errorf("iterate edges: %w", err) + } + return edges, nil +} + // Note: nullString is defined in helpers.go diff --git a/internal/infra/postgres/asset_repository.go b/internal/infra/postgres/asset_repository.go index 44c2b5e9..d8a3fce2 100644 --- a/internal/infra/postgres/asset_repository.go +++ b/internal/infra/postgres/asset_repository.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "regexp" "strings" "time" @@ -285,7 +286,7 @@ func (r *AssetRepository) Update(ctx context.Context, a *asset.Asset) error { compliance_scope = $26, data_classification = $27, pii_data_exposed = $28, phi_data_exposed = $29, regulatory_owner_id = $30, is_internet_accessible = $31, exposure_changed_at = $32, last_exposure_level = $33, last_seen = $34, updated_at = $35 - WHERE id = $1 + WHERE id = $1 AND tenant_id = $36 ` updateOwnerRef := sql.NullString{String: a.OwnerRef(), Valid: a.OwnerRef() != ""} @@ -325,6 +326,7 @@ func (r *AssetRepository) Update(ctx context.Context, a *asset.Asset) error { nullString(string(a.LastExposureLevel())), a.LastSeen(), a.UpdatedAt(), + a.TenantID().String(), ) if err != nil { @@ -1059,12 +1061,19 @@ func (r *AssetRepository) UpsertBatch(ctx context.Context, assets []*asset.Asset // ListDistinctTags returns distinct tags across all assets for a tenant. // Supports prefix filtering for autocomplete and a limit for result size. -func (r *AssetRepository) ListDistinctTags(ctx context.Context, tenantID shared.ID, prefix string, limit int) ([]string, error) { +func (r *AssetRepository) ListDistinctTags(ctx context.Context, tenantID shared.ID, prefix string, types []string, limit int) ([]string, error) { query := `SELECT DISTINCT tag FROM assets, unnest(tags) AS tag WHERE tenant_id = $1` args := []any{tenantID.String()} + argIdx := 2 + + if len(types) > 0 { + query += fmt.Sprintf(` AND asset_type = ANY($%d)`, argIdx) + args = append(args, pq.Array(types)) + argIdx++ + } if prefix != "" { - query += ` AND tag ILIKE $2` + query += fmt.Sprintf(` AND tag ILIKE $%d`, argIdx) args = append(args, escapeLikePattern(prefix)+"%") } @@ -1230,14 +1239,15 @@ func (r *AssetRepository) GetAverageRiskScore(ctx context.Context, tenantID shar // This version collapses everything into one query using a CTE + UNION ALL, // trading slightly more complex SQL for an 83% reduction in DB round-trips. // PostgreSQL plans a single scan of the filtered CTE for all aggregates. -func (r *AssetRepository) GetAggregateStats(ctx context.Context, tenantID shared.ID, types []string, tags []string, subType string) (*asset.AggregateStats, error) { +func (r *AssetRepository) GetAggregateStats(ctx context.Context, tenantID shared.ID, types []string, tags []string, subType string, countByFields ...string) (*asset.AggregateStats, error) { stats := &asset.AggregateStats{ - ByType: make(map[string]int), - BySubType: make(map[string]int), - ByStatus: make(map[string]int), - ByCriticality: make(map[string]int), - ByScope: make(map[string]int), - ByExposure: make(map[string]int), + ByType: make(map[string]int), + BySubType: make(map[string]int), + ByStatus: make(map[string]int), + ByCriticality: make(map[string]int), + ByScope: make(map[string]int), + ByExposure: make(map[string]int), + MetadataCounts: make(map[string]map[string]int), } // Build the WHERE clause once. @@ -1305,6 +1315,26 @@ SELECT category, key, value FROM ( ) sub `, filterClause) + // Append metadata count queries for requested JSONB property fields. + // Each field adds a UNION ALL that groups by the property value. + // Only alphanumeric + underscore field names are allowed (SQL injection safe). + validField := regexp.MustCompile(`^[a-z][a-z0-9_]{0,49}$`) + for _, field := range countByFields { + if !validField.MatchString(field) { + continue + } + // Insert before the closing ") sub" by replacing it + query = strings.TrimSuffix(strings.TrimSpace(query), ") sub") + query += fmt.Sprintf(` + UNION ALL + SELECT 'meta:%s', COALESCE(a.properties->>'%s', 'null'), COUNT(*)::float8 + FROM filtered a + WHERE a.properties ? '%s' + GROUP BY a.properties->>'%s' +) sub +`, field, field, field, field) + } + rows, err := r.db.QueryContext(ctx, query, args...) if err != nil { return nil, fmt.Errorf("failed to get aggregate stats: %w", err) @@ -1340,6 +1370,15 @@ SELECT category, key, value FROM ( stats.ByScope[key] = int(value) case "exposure": stats.ByExposure[key] = int(value) + default: + // Handle dynamic metadata counts: "meta:field_name" + if strings.HasPrefix(category, "meta:") { + field := strings.TrimPrefix(category, "meta:") + if stats.MetadataCounts[field] == nil { + stats.MetadataCounts[field] = make(map[string]int) + } + stats.MetadataCounts[field][key] = int(value) + } } } if err := rows.Err(); err != nil { @@ -1350,94 +1389,102 @@ SELECT category, key, value FROM ( } // GetPropertyFacets returns distinct JSONB property keys and their top values. -// Uses a 2-step approach: first get keys, then get top values per key. +// Uses a single query that expands JSONB keys and values together, then groups +// in Go — replacing the previous 1+N query pattern. func (r *AssetRepository) GetPropertyFacets(ctx context.Context, tenantID shared.ID, types []string, subType string) ([]asset.PropertyFacet, error) { - // Build WHERE clause - where := "WHERE a.tenant_id = $1 AND a.properties IS NOT NULL AND a.properties != '{}'::jsonb" + // Build optional extra filter clauses (applied inside the sub-select). + extraWhere := "" args := []any{tenantID.String()} idx := 2 if len(types) > 0 { - where += fmt.Sprintf(" AND a.asset_type = ANY($%d::text[])", idx) + extraWhere += fmt.Sprintf(" AND a.asset_type = ANY($%d::text[])", idx) args = append(args, pq.Array(types)) idx++ } if subType != "" { - where += fmt.Sprintf(" AND a.sub_type = $%d", idx) + extraWhere += fmt.Sprintf(" AND a.sub_type = $%d", idx) args = append(args, subType) - idx++ + // idx++ — not needed after the last param } - // Step 1: Get top property keys with counts (skip array/object values, only scalar) - keysQuery := fmt.Sprintf(` - SELECT key, COUNT(*) as cnt + // Single query: expand every JSONB key/value pair per asset, then aggregate. + // jsonb_object_keys() is a set-returning function; using it twice in the same + // SELECT causes a parallel scan that PostgreSQL evaluates consistently. + // We filter out known array/object keys and keep only scalar values. + query := fmt.Sprintf(` + SELECT key, val, COUNT(*) AS cnt FROM ( - SELECT jsonb_object_keys(a.properties) as key - FROM assets a %s - ) keys + SELECT + jsonb_object_keys(a.properties) AS key, + a.properties ->> jsonb_object_keys(a.properties) AS val + FROM assets a + WHERE a.tenant_id = $1 + AND a.properties IS NOT NULL + AND a.properties != '{}'::jsonb + %s + ) sub WHERE key NOT IN ('ip_addresses', 'dns_records', 'ports', 'technologies', 'interfaces', 'tags') - GROUP BY key - HAVING COUNT(*) >= 2 - ORDER BY cnt DESC - LIMIT 15 - `, where) + AND val IS NOT NULL + AND val != '' + GROUP BY key, val + ORDER BY key, cnt DESC + `, extraWhere) - keyRows, err := r.db.QueryContext(ctx, keysQuery, args...) + rows, err := r.db.QueryContext(ctx, query, args...) if err != nil { - return nil, fmt.Errorf("failed to get property keys: %w", err) + return nil, fmt.Errorf("failed to get property facets: %w", err) } - defer func() { _ = keyRows.Close() }() + defer func() { _ = rows.Close() }() - type keyInfo struct { - key string - count int + // Group results by key in insertion order; track per-key asset count. + type facetAccum struct { + values []string + totalCount int // sum of per-value counts (≈ asset count for this key) } - keys := make([]keyInfo, 0, 15) - for keyRows.Next() { - var k string - var c int - if err := keyRows.Scan(&k, &c); err != nil { - return nil, fmt.Errorf("failed to scan key: %w", err) + order := make([]string, 0, 15) + accum := make(map[string]*facetAccum) + + for rows.Next() { + var key, val string + var cnt int + if err := rows.Scan(&key, &val, &cnt); err != nil { + return nil, fmt.Errorf("failed to scan facet row: %w", err) + } + + if _, seen := accum[key]; !seen { + accum[key] = &facetAccum{} + order = append(order, key) + } + fa := accum[key] + fa.totalCount += cnt + // Keep only the top 20 values per key (rows are ordered by cnt DESC within each key). + if len(fa.values) < 20 { + fa.values = append(fa.values, val) } - keys = append(keys, keyInfo{key: k, count: c}) } - if err := keyRows.Err(); err != nil { - return nil, fmt.Errorf("error iterating keys: %w", err) + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating facet rows: %w", err) } - // Step 2: For each key, get distinct values (top 20) - facets := make([]asset.PropertyFacet, 0, len(keys)) - for _, ki := range keys { - valQuery := fmt.Sprintf(` - SELECT DISTINCT a.properties->>$%d AS val - FROM assets a %s AND a.properties ? $%d - AND jsonb_typeof(a.properties->$%d) IN ('string', 'number', 'boolean') - ORDER BY val - LIMIT 20 - `, idx, where, idx, idx) + // Build facets: skip keys with total_count < 2, limit to top 15 keys by count. + const maxKeys = 15 + const minCount = 2 - valArgs := append(args, ki.key) //nolint:gocritic - valRows, err := r.db.QueryContext(ctx, valQuery, valArgs...) - if err != nil { - continue // Skip this key on error + facets := make([]asset.PropertyFacet, 0, len(order)) + for _, key := range order { + fa := accum[key] + if fa.totalCount < minCount { + continue } - - values := make([]string, 0, 20) - for valRows.Next() { - var v sql.NullString - if err := valRows.Scan(&v); err == nil && v.Valid && v.String != "" { - values = append(values, v.String) - } - } - _ = valRows.Close() - - if len(values) > 0 { - facets = append(facets, asset.PropertyFacet{ - Key: ki.key, - Label: formatPropertyLabel(ki.key), - Values: values, - Count: ki.count, - }) + facets = append(facets, asset.PropertyFacet{ + Key: key, + Label: formatPropertyLabel(key), + Values: fa.values, + Count: fa.totalCount, + }) + if len(facets) == maxKeys { + break } } @@ -1454,3 +1501,49 @@ func formatPropertyLabel(key string) string { } return strings.Join(words, " ") } + +// ListAllNodes fetches every asset for the tenant as lightweight graph nodes. +// Used by attack path scoring to build the full in-memory directed graph. +// Only the columns needed for scoring are fetched. +func (r *AssetRepository) ListAllNodes(ctx context.Context, tenantID shared.ID) ([]asset.AssetNode, error) { + const query = ` + SELECT + a.id, + a.name, + a.asset_type, + a.exposure, + a.criticality, + a.risk_score, + COALESCE(a.is_crown_jewel, FALSE), + COALESCE(fc.finding_count, 0) + FROM assets a + LEFT JOIN ( + SELECT asset_id, COUNT(*) AS finding_count + FROM findings + GROUP BY asset_id + ) fc ON fc.asset_id = a.id + WHERE a.tenant_id = $1 + ORDER BY a.created_at + ` + rows, err := r.db.QueryContext(ctx, query, tenantID.String()) + if err != nil { + return nil, fmt.Errorf("list all nodes: %w", err) + } + defer func() { _ = rows.Close() }() + + var nodes []asset.AssetNode + for rows.Next() { + var n asset.AssetNode + if scanErr := rows.Scan( + &n.ID, &n.Name, &n.AssetType, &n.Exposure, + &n.Criticality, &n.RiskScore, &n.IsCrownJewel, &n.FindingCount, + ); scanErr != nil { + return nil, fmt.Errorf("scan node: %w", scanErr) + } + nodes = append(nodes, n) + } + if err = rows.Err(); err != nil { + return nil, fmt.Errorf("iterate nodes: %w", err) + } + return nodes, nil +} diff --git a/internal/infra/postgres/component_repository.go b/internal/infra/postgres/component_repository.go index 7924d048..b8269ce1 100644 --- a/internal/infra/postgres/component_repository.go +++ b/internal/infra/postgres/component_repository.go @@ -722,7 +722,10 @@ func (r *ComponentRepository) GetStats(ctx context.Context, tenantID shared.ID) AND v.cisa_kev_date_added IS NOT NULL AND f.status NOT IN ('resolved', 'false_positive') ` - _ = r.db.QueryRowContext(ctx, kevQuery, tenantID.String()).Scan(&stats.CisaKevComponents) + if err := r.db.QueryRowContext(ctx, kevQuery, tenantID.String()).Scan(&stats.CisaKevComponents); err != nil && !errors.Is(err, sql.ErrNoRows) { + // Non-critical metric — continue with zero value if query fails + stats.CisaKevComponents = 0 + } // License risk breakdown from component_licenses → licenses licenseRiskQuery := ` diff --git a/internal/infra/postgres/control_test_repository.go b/internal/infra/postgres/control_test_repository.go index 86b97d14..feb8fbd2 100644 --- a/internal/infra/postgres/control_test_repository.go +++ b/internal/infra/postgres/control_test_repository.go @@ -257,3 +257,70 @@ func (r *ControlTestRepository) GetStatsByFramework(ctx context.Context, tenantI return stats, nil } + +// ListOverdue returns control tests that have not been tested in at least staleDays days, +// or that have never been tested (last_tested_at IS NULL). Scans across all tenants. +// Returns at most limit rows ordered by longest-overdue first. +func (r *ControlTestRepository) ListOverdue(ctx context.Context, staleDays int, limit int) ([]*simulation.OverdueControlTest, error) { + if limit <= 0 { + limit = 500 + } + query := ` + SELECT tenant_id, id, name, framework, + COALESCE(EXTRACT(DAY FROM NOW() - last_tested_at)::int, $1) AS days_since_tested + FROM control_tests + WHERE last_tested_at IS NULL + OR last_tested_at < NOW() - ($2 || ' days')::interval + ORDER BY days_since_tested DESC + LIMIT $3` + + rows, err := r.db.QueryContext(ctx, query, staleDays, staleDays, limit) + if err != nil { + return nil, fmt.Errorf("failed to list overdue control tests: %w", err) + } + defer rows.Close() + + results := make([]*simulation.OverdueControlTest, 0) + for rows.Next() { + var tenantIDStr, idStr, name, framework string + var daysSince int + if err := rows.Scan(&tenantIDStr, &idStr, &name, &framework, &daysSince); err != nil { + return nil, fmt.Errorf("failed to scan overdue control test: %w", err) + } + tenantID, err := shared.IDFromString(tenantIDStr) + if err != nil { + continue + } + ctID, err := shared.IDFromString(idStr) + if err != nil { + continue + } + results = append(results, &simulation.OverdueControlTest{ + TenantID: tenantID, + ControlTestID: ctID, + Name: name, + Framework: framework, + DaysSinceTested: daysSince, + }) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("failed to iterate overdue control tests: %w", err) + } + return results, nil +} + +// MarkOverdue resets a control test's status to 'untested' to signal stale detection coverage. +// This is a deliberate write — the scheduler only calls it when a test has genuinely lapsed. +func (r *ControlTestRepository) MarkOverdue(ctx context.Context, tenantID, id shared.ID) error { + query := `UPDATE control_tests SET status = $3, updated_at = $4 + WHERE tenant_id = $1 AND id = $2 AND status NOT IN ('untested')` + + _, err := r.db.ExecContext(ctx, query, + tenantID.String(), id.String(), + string(simulation.ControlTestStatusUntested), time.Now(), + ) + if err != nil { + return fmt.Errorf("failed to mark control test overdue: %w", err) + } + return nil +} diff --git a/internal/infra/postgres/finding_repository.go b/internal/infra/postgres/finding_repository.go index 1850f12a..a6cced25 100644 --- a/internal/infra/postgres/finding_repository.go +++ b/internal/infra/postgres/finding_repository.go @@ -795,6 +795,40 @@ func (r *FindingRepository) Delete(ctx context.Context, tenantID, id shared.ID) return nil } +// UpdateWorkItemURIs sets the work_item_uris column for a finding. +// Only modifies work_item_uris — all other fields are untouched. +func (r *FindingRepository) UpdateWorkItemURIs(ctx context.Context, tenantID, id shared.ID, uris []string) error { + if uris == nil { + uris = []string{} + } + query := `UPDATE findings SET work_item_uris = $3, updated_at = NOW() + WHERE tenant_id = $1 AND id = $2` + result, err := r.db.ExecContext(ctx, query, tenantID.String(), id.String(), pq.Array(uris)) + if err != nil { + return fmt.Errorf("failed to update work item uris: %w", err) + } + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + if rowsAffected == 0 { + return vulnerability.FindingNotFoundError(id) + } + return nil +} + +// GetByWorkItemURI retrieves a finding that contains the given work item URI. +// Used by the Jira webhook receiver to route inbound status updates back to findings. +func (r *FindingRepository) GetByWorkItemURI(ctx context.Context, tenantID shared.ID, uri string) (*vulnerability.Finding, error) { + query := r.selectQuery() + ` WHERE tenant_id = $1 AND $2 = ANY(work_item_uris) LIMIT 1` + row := r.db.QueryRowContext(ctx, query, tenantID.String(), uri) + finding, err := r.scanFinding(row, fmt.Errorf("%w: finding not found for work item URI", shared.ErrNotFound)) + if err != nil { + return nil, fmt.Errorf("failed to get finding by work item URI: %w", err) + } + return finding, nil +} + // List retrieves findings matching the filter with pagination. func (r *FindingRepository) List(ctx context.Context, filter vulnerability.FindingFilter, opts vulnerability.FindingListOptions, page pagination.Pagination) (pagination.Result[*vulnerability.Finding], error) { baseQuery := r.selectQuery() diff --git a/internal/infra/postgres/relationship_suggestion_repository.go b/internal/infra/postgres/relationship_suggestion_repository.go new file mode 100644 index 00000000..d4cbb9c8 --- /dev/null +++ b/internal/infra/postgres/relationship_suggestion_repository.go @@ -0,0 +1,412 @@ +package postgres + +import ( + "context" + "database/sql" + "errors" + "fmt" + "time" + + "github.com/openctemio/api/pkg/domain/relationship" + "github.com/openctemio/api/pkg/domain/shared" + "github.com/openctemio/api/pkg/pagination" +) + +// RelationshipSuggestionRepository implements relationship.SuggestionRepository using PostgreSQL. +type RelationshipSuggestionRepository struct { + db *DB +} + +// NewRelationshipSuggestionRepository creates a new RelationshipSuggestionRepository. +func NewRelationshipSuggestionRepository(db *DB) *RelationshipSuggestionRepository { + return &RelationshipSuggestionRepository{db: db} +} + +// Create persists a new suggestion. +func (r *RelationshipSuggestionRepository) Create(ctx context.Context, s *relationship.Suggestion) error { + query := ` + INSERT INTO relationship_suggestions ( + id, tenant_id, source_asset_id, target_asset_id, + relationship_type, reason, confidence, status, created_at + ) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + ` + + _, err := r.db.ExecContext(ctx, query, + s.ID().String(), + s.TenantID().String(), + s.SourceAssetID().String(), + s.TargetAssetID().String(), + s.RelationshipType(), + s.Reason(), + s.Confidence(), + s.Status(), + s.CreatedAt(), + ) + if err != nil { + if isUniqueViolation(err) { + return fmt.Errorf("%w: suggestion already exists", shared.ErrAlreadyExists) + } + return fmt.Errorf("failed to create suggestion: %w", err) + } + + return nil +} + +// CreateBatch inserts multiple suggestions, skipping duplicates via ON CONFLICT DO NOTHING. +func (r *RelationshipSuggestionRepository) CreateBatch(ctx context.Context, suggestions []*relationship.Suggestion) (int, error) { + if len(suggestions) == 0 { + return 0, nil + } + + query := ` + INSERT INTO relationship_suggestions ( + id, tenant_id, source_asset_id, target_asset_id, + relationship_type, reason, confidence, status, created_at + ) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + ON CONFLICT (tenant_id, source_asset_id, target_asset_id, relationship_type) DO NOTHING + ` + + tx, err := r.db.BeginTx(ctx, nil) + if err != nil { + return 0, fmt.Errorf("failed to begin transaction: %w", err) + } + defer func() { _ = tx.Rollback() }() + + stmt, err := tx.PrepareContext(ctx, query) + if err != nil { + return 0, fmt.Errorf("failed to prepare statement: %w", err) + } + defer func() { _ = stmt.Close() }() + + created := 0 + for _, s := range suggestions { + result, execErr := stmt.ExecContext(ctx, + s.ID().String(), + s.TenantID().String(), + s.SourceAssetID().String(), + s.TargetAssetID().String(), + s.RelationshipType(), + s.Reason(), + s.Confidence(), + s.Status(), + s.CreatedAt(), + ) + if execErr != nil { + return created, fmt.Errorf("failed to insert suggestion: %w", execErr) + } + rowsAffected, _ := result.RowsAffected() + if rowsAffected > 0 { + created++ + } + } + + if err := tx.Commit(); err != nil { + return 0, fmt.Errorf("failed to commit transaction: %w", err) + } + + return created, nil +} + +// GetByID retrieves a suggestion by ID within a tenant. +func (r *RelationshipSuggestionRepository) GetByID(ctx context.Context, tenantID, id shared.ID) (*relationship.Suggestion, error) { + query := ` + SELECT id, tenant_id, source_asset_id, target_asset_id, + relationship_type, reason, confidence, status, + reviewed_by, reviewed_at, created_at + FROM relationship_suggestions + WHERE tenant_id = $1 AND id = $2 + ` + + row := r.db.QueryRowContext(ctx, query, tenantID.String(), id.String()) + s, err := r.scanSuggestion(row) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, relationship.ErrSuggestionNotFound + } + return nil, fmt.Errorf("failed to get suggestion: %w", err) + } + + return s, nil +} + +// ListPending retrieves pending suggestions for a tenant with pagination and optional search. +func (r *RelationshipSuggestionRepository) ListPending(ctx context.Context, tenantID shared.ID, search string, page pagination.Pagination) (pagination.Result[*relationship.Suggestion], error) { + // Build WHERE clause + where := "rs.tenant_id = $1 AND rs.status = 'pending'" + args := []any{tenantID.String()} + idx := 2 + + if search != "" { + where += fmt.Sprintf(` AND (sa.name ILIKE $%d OR ta.name ILIKE $%d)`, idx, idx) + args = append(args, "%"+escapeLikePattern(search)+"%") + idx++ + } + + // Count total + countQuery := fmt.Sprintf(` + SELECT COUNT(*) + FROM relationship_suggestions rs + LEFT JOIN assets sa ON rs.source_asset_id = sa.id + LEFT JOIN assets ta ON rs.target_asset_id = ta.id + WHERE %s`, where) + var total int64 + if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil { + return pagination.Result[*relationship.Suggestion]{}, fmt.Errorf("failed to count suggestions: %w", err) + } + + if total == 0 { + return pagination.NewResult(make([]*relationship.Suggestion, 0), 0, page), nil + } + + // Fetch page with JOINed asset names + query := fmt.Sprintf(` + SELECT rs.id, rs.tenant_id, rs.source_asset_id, rs.target_asset_id, + rs.relationship_type, rs.reason, rs.confidence, rs.status, + rs.reviewed_by, rs.reviewed_at, rs.created_at, + COALESCE(sa.name, ''), COALESCE(sa.asset_type, ''), + COALESCE(ta.name, ''), COALESCE(ta.asset_type, '') + FROM relationship_suggestions rs + LEFT JOIN assets sa ON rs.source_asset_id = sa.id + LEFT JOIN assets ta ON rs.target_asset_id = ta.id + WHERE %s + ORDER BY rs.confidence DESC, rs.created_at DESC + LIMIT $%d OFFSET $%d + `, where, idx, idx+1) + args = append(args, page.Limit(), page.Offset()) + + rows, err := r.db.QueryContext(ctx, query, args...) + if err != nil { + return pagination.Result[*relationship.Suggestion]{}, fmt.Errorf("failed to list suggestions: %w", err) + } + defer func() { _ = rows.Close() }() + + suggestions := make([]*relationship.Suggestion, 0, 100) + for rows.Next() { + s, scanErr := r.scanSuggestionWithAssets(rows) + if scanErr != nil { + return pagination.Result[*relationship.Suggestion]{}, fmt.Errorf("failed to scan suggestion: %w", scanErr) + } + suggestions = append(suggestions, s) + } + if err = rows.Err(); err != nil { + return pagination.Result[*relationship.Suggestion]{}, fmt.Errorf("failed to iterate suggestions: %w", err) + } + + return pagination.NewResult(suggestions, total, page), nil +} + +// UpdateStatus updates the status, reviewed_by, and reviewed_at of a suggestion. +func (r *RelationshipSuggestionRepository) UpdateStatus(ctx context.Context, s *relationship.Suggestion) error { + query := ` + UPDATE relationship_suggestions + SET status = $3, reviewed_by = $4, reviewed_at = $5 + WHERE tenant_id = $1 AND id = $2 + ` + + result, err := r.db.ExecContext(ctx, query, + s.TenantID().String(), + s.ID().String(), + s.Status(), + nullIDPtr(s.ReviewedBy()), + s.ReviewedAt(), + ) + if err != nil { + return fmt.Errorf("failed to update suggestion status: %w", err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + if rowsAffected == 0 { + return relationship.ErrSuggestionNotFound + } + + return nil +} + +// CountPending returns the number of pending suggestions for a tenant. +func (r *RelationshipSuggestionRepository) CountPending(ctx context.Context, tenantID shared.ID) (int64, error) { + query := `SELECT COUNT(*) FROM relationship_suggestions WHERE tenant_id = $1 AND status = 'pending'` + var count int64 + if err := r.db.QueryRowContext(ctx, query, tenantID.String()).Scan(&count); err != nil { + return 0, fmt.Errorf("failed to count pending suggestions: %w", err) + } + return count, nil +} + +// DeleteByAssetID deletes all suggestions involving a given asset. +func (r *RelationshipSuggestionRepository) DeleteByAssetID(ctx context.Context, tenantID, assetID shared.ID) error { + query := ` + DELETE FROM relationship_suggestions + WHERE tenant_id = $1 AND (source_asset_id = $2 OR target_asset_id = $2) + ` + _, err := r.db.ExecContext(ctx, query, tenantID.String(), assetID.String()) + if err != nil { + return fmt.Errorf("failed to delete suggestions by asset: %w", err) + } + return nil +} + +// DeletePending removes all pending suggestions for a tenant (used before re-scan). +func (r *RelationshipSuggestionRepository) DeletePending(ctx context.Context, tenantID shared.ID) error { + _, err := r.db.ExecContext(ctx, + `DELETE FROM relationship_suggestions WHERE tenant_id = $1 AND status = 'pending'`, + tenantID.String(), + ) + if err != nil { + return fmt.Errorf("failed to delete pending suggestions: %w", err) + } + return nil +} + +// ApproveAll marks all pending suggestions as approved and returns them. +func (r *RelationshipSuggestionRepository) ApproveAll(ctx context.Context, tenantID, reviewerID shared.ID) ([]*relationship.Suggestion, error) { + now := time.Now().UTC() + + query := ` + UPDATE relationship_suggestions + SET status = 'approved', reviewed_by = $2, reviewed_at = $3 + WHERE tenant_id = $1 AND status = 'pending' + RETURNING id, tenant_id, source_asset_id, target_asset_id, + relationship_type, reason, confidence, status, + reviewed_by, reviewed_at, created_at + ` + + rows, err := r.db.QueryContext(ctx, query, tenantID.String(), reviewerID.String(), now) + if err != nil { + return nil, fmt.Errorf("failed to approve all suggestions: %w", err) + } + defer func() { _ = rows.Close() }() + + suggestions := make([]*relationship.Suggestion, 0) + for rows.Next() { + s, scanErr := r.scanSuggestion(rows) + if scanErr != nil { + return nil, fmt.Errorf("failed to scan approved suggestion: %w", scanErr) + } + suggestions = append(suggestions, s) + } + if err = rows.Err(); err != nil { + return nil, fmt.Errorf("failed to iterate approved suggestions: %w", err) + } + + return suggestions, nil +} + +// UpdateRelationshipType updates only the relationship_type of a pending suggestion. +func (r *RelationshipSuggestionRepository) UpdateRelationshipType(ctx context.Context, tenantID, id shared.ID, relType string) error { + query := ` + UPDATE relationship_suggestions + SET relationship_type = $3 + WHERE tenant_id = $1 AND id = $2 AND status = 'pending' + ` + result, err := r.db.ExecContext(ctx, query, tenantID.String(), id.String(), relType) + if err != nil { + return fmt.Errorf("failed to update suggestion relationship type: %w", err) + } + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + if rowsAffected == 0 { + return relationship.ErrSuggestionNotFound + } + return nil +} + +// ============================================================================= +// Internal helpers +// ============================================================================= + +// suggestionScanner is satisfied by both *sql.Row and *sql.Rows. +type suggestionScanner interface { + Scan(dest ...any) error +} + +func (r *RelationshipSuggestionRepository) scanSuggestion(row suggestionScanner) (*relationship.Suggestion, error) { + var ( + id string + tenantID string + sourceAssetID string + targetAssetID string + relType string + reason string + confidence float64 + status string + reviewedByStr sql.NullString + reviewedAt *time.Time + createdAt time.Time + ) + + err := row.Scan( + &id, &tenantID, &sourceAssetID, &targetAssetID, + &relType, &reason, &confidence, &status, + &reviewedByStr, &reviewedAt, &createdAt, + ) + if err != nil { + return nil, err + } + + var reviewedBy *shared.ID + if reviewedByStr.Valid { + parsedID := shared.MustIDFromString(reviewedByStr.String) + reviewedBy = &parsedID + } + + return relationship.ReconstituteSuggestion( + shared.MustIDFromString(id), + shared.MustIDFromString(tenantID), + shared.MustIDFromString(sourceAssetID), + shared.MustIDFromString(targetAssetID), + relType, + reason, + confidence, + status, + reviewedBy, + reviewedAt, + createdAt, + ), nil +} + +// scanSuggestionWithAssets scans a suggestion row that includes JOINed asset name/type columns. +func (r *RelationshipSuggestionRepository) scanSuggestionWithAssets(row suggestionScanner) (*relationship.Suggestion, error) { + var ( + id, tenantID, sourceAssetID, targetAssetID string + relType, reason, status string + confidence float64 + reviewedByStr sql.NullString + reviewedAt *time.Time + createdAt time.Time + srcName, srcType, tgtName, tgtType string + ) + + err := row.Scan( + &id, &tenantID, &sourceAssetID, &targetAssetID, + &relType, &reason, &confidence, &status, + &reviewedByStr, &reviewedAt, &createdAt, + &srcName, &srcType, &tgtName, &tgtType, + ) + if err != nil { + return nil, err + } + + var reviewedBy *shared.ID + if reviewedByStr.Valid { + parsedID := shared.MustIDFromString(reviewedByStr.String) + reviewedBy = &parsedID + } + + s := relationship.ReconstituteSuggestion( + shared.MustIDFromString(id), + shared.MustIDFromString(tenantID), + shared.MustIDFromString(sourceAssetID), + shared.MustIDFromString(targetAssetID), + relType, reason, confidence, status, + reviewedBy, reviewedAt, createdAt, + ) + s.SetAssetInfo(srcName, srcType, tgtName, tgtType) + return s, nil +} diff --git a/migrations/000133_fix_subdomain_types.down.sql b/migrations/000133_fix_subdomain_types.down.sql new file mode 100644 index 00000000..e91a861b --- /dev/null +++ b/migrations/000133_fix_subdomain_types.down.sql @@ -0,0 +1,4 @@ +-- Cannot reliably reverse — subdomains may have been correctly typed before. +-- This migration only fixes mistyped assets, so rollback is a no-op. +-- If needed, manually UPDATE specific assets back to 'domain'. +SELECT 1; diff --git a/migrations/000133_fix_subdomain_types.up.sql b/migrations/000133_fix_subdomain_types.up.sql new file mode 100644 index 00000000..9f199afb --- /dev/null +++ b/migrations/000133_fix_subdomain_types.up.sql @@ -0,0 +1,21 @@ +-- ============================================================================= +-- Migration 000133: Fix subdomains incorrectly typed as 'domain' +-- ============================================================================= +-- Detects subdomains by checking if a parent root domain exists in the system. +-- A domain X is a subdomain if another domain Y exists where X ends with '.Y' +-- and Y has fewer dots (is a shorter/higher-level domain). +-- ============================================================================= + +-- Fix: assets typed as 'domain' that are actually subdomains +-- (they have a parent domain in the same tenant) +UPDATE assets a1 +SET asset_type = 'subdomain' +WHERE a1.asset_type = 'domain' + AND EXISTS ( + SELECT 1 FROM assets a2 + WHERE a2.tenant_id = a1.tenant_id + AND a2.asset_type = 'domain' + AND a2.id != a1.id + AND a1.name LIKE '%.' || a2.name + AND length(a2.name) < length(a1.name) + ); diff --git a/migrations/000134_extract_dns_fields_from_properties.down.sql b/migrations/000134_extract_dns_fields_from_properties.down.sql new file mode 100644 index 00000000..17ba0a52 --- /dev/null +++ b/migrations/000134_extract_dns_fields_from_properties.down.sql @@ -0,0 +1,4 @@ +-- Rollback: remove extracted flat fields (original nested data preserved) +UPDATE assets SET properties = properties - 'record_type' - 'resolved_ip' - 'cname_target' - 'ttl' - 'dns_record_types' - 'resolved_ips' - 'dns_record_count' +WHERE (asset_type = 'domain' OR asset_type = 'subdomain') + AND properties IS NOT NULL; diff --git a/migrations/000134_extract_dns_fields_from_properties.up.sql b/migrations/000134_extract_dns_fields_from_properties.up.sql new file mode 100644 index 00000000..53da2e6c --- /dev/null +++ b/migrations/000134_extract_dns_fields_from_properties.up.sql @@ -0,0 +1,92 @@ +-- ============================================================================= +-- Migration 000134: Extract DNS fields from nested properties to flat fields +-- ============================================================================= +-- Collector stores DNS data as: +-- domain: {"domain": {"dns_records": [{"ttl":300,"name":"x","type":"A","value":"1.2.3.4"}]}, "collector_type":"gcp-dns", "collector_source":"vndirect-compute"} +-- subdomain: {"domain": {"dns_records": [...]}, "root_domain":"parent.com", "collector_type":"gcp-dns", ...} +-- +-- UI reads flat fields: record_type, provider, registrar, nameserver, ip_address +-- This migration extracts from nested JSONB → flat top-level properties keys. +-- ============================================================================= + +-- Step 1: Extract first DNS record type → record_type (e.g., "A", "CNAME", "AAAA", "MX") +UPDATE assets SET properties = properties || jsonb_build_object( + 'record_type', (properties->'domain'->'dns_records'->0->>'type') +) +WHERE (asset_type = 'domain' OR asset_type = 'subdomain') + AND properties->'domain'->'dns_records' IS NOT NULL + AND jsonb_array_length(properties->'domain'->'dns_records') > 0 + AND (properties->>'record_type' IS NULL OR properties->>'record_type' = ''); + +-- Step 2: Extract resolved IP from A/AAAA records → resolved_ip +UPDATE assets SET properties = properties || jsonb_build_object( + 'resolved_ip', (properties->'domain'->'dns_records'->0->>'value') +) +WHERE (asset_type = 'domain' OR asset_type = 'subdomain') + AND properties->'domain'->'dns_records' IS NOT NULL + AND jsonb_array_length(properties->'domain'->'dns_records') > 0 + AND (properties->'domain'->'dns_records'->0->>'type') IN ('A', 'AAAA') + AND (properties->>'resolved_ip' IS NULL OR properties->>'resolved_ip' = ''); + +-- Step 3: Extract CNAME target → cname_target +UPDATE assets SET properties = properties || jsonb_build_object( + 'cname_target', (properties->'domain'->'dns_records'->0->>'value') +) +WHERE (asset_type = 'domain' OR asset_type = 'subdomain') + AND properties->'domain'->'dns_records' IS NOT NULL + AND jsonb_array_length(properties->'domain'->'dns_records') > 0 + AND (properties->'domain'->'dns_records'->0->>'type') = 'CNAME' + AND (properties->>'cname_target' IS NULL OR properties->>'cname_target' = ''); + +-- Step 4: Extract TTL → ttl +UPDATE assets SET properties = properties || jsonb_build_object( + 'ttl', (properties->'domain'->'dns_records'->0->'ttl') +) +WHERE (asset_type = 'domain' OR asset_type = 'subdomain') + AND properties->'domain'->'dns_records' IS NOT NULL + AND jsonb_array_length(properties->'domain'->'dns_records') > 0 + AND (properties->>'ttl' IS NULL); + +-- Step 5: Extract all unique record types → dns_record_types (comma-separated) +UPDATE assets SET properties = properties || jsonb_build_object( + 'dns_record_types', ( + SELECT string_agg(DISTINCT rec->>'type', ', ' ORDER BY rec->>'type') + FROM jsonb_array_elements(properties->'domain'->'dns_records') AS rec + ) +) +WHERE (asset_type = 'domain' OR asset_type = 'subdomain') + AND properties->'domain'->'dns_records' IS NOT NULL + AND jsonb_array_length(properties->'domain'->'dns_records') > 0 + AND (properties->>'dns_record_types' IS NULL OR properties->>'dns_record_types' = ''); + +-- Step 6: Extract all resolved IPs → resolved_ips (comma-separated, A/AAAA only) +UPDATE assets SET properties = properties || jsonb_build_object( + 'resolved_ips', ( + SELECT string_agg(DISTINCT rec->>'value', ', ') + FROM jsonb_array_elements(properties->'domain'->'dns_records') AS rec + WHERE rec->>'type' IN ('A', 'AAAA') + ) +) +WHERE (asset_type = 'domain' OR asset_type = 'subdomain') + AND properties->'domain'->'dns_records' IS NOT NULL + AND jsonb_array_length(properties->'domain'->'dns_records') > 0 + AND (properties->>'resolved_ips' IS NULL OR properties->>'resolved_ips' = ''); + +-- Step 7: Promote collector_type and collector_source to flat fields (already flat, just ensure present) +-- These are already at top level — no action needed. + +-- Step 8: Promote root_domain for subdomains (strip trailing dot) +UPDATE assets SET properties = properties || jsonb_build_object( + 'root_domain', RTRIM(properties->>'root_domain', '.') +) +WHERE asset_type = 'subdomain' + AND properties->>'root_domain' IS NOT NULL + AND properties->>'root_domain' LIKE '%.'; + +-- Step 9: Extract dns_record_count +UPDATE assets SET properties = properties || jsonb_build_object( + 'dns_record_count', jsonb_array_length(properties->'domain'->'dns_records') +) +WHERE (asset_type = 'domain' OR asset_type = 'subdomain') + AND properties->'domain'->'dns_records' IS NOT NULL + AND (properties->>'dns_record_count' IS NULL); diff --git a/migrations/000135_normalize_property_keys_snake_case.down.sql b/migrations/000135_normalize_property_keys_snake_case.down.sql new file mode 100644 index 00000000..616a64fc --- /dev/null +++ b/migrations/000135_normalize_property_keys_snake_case.down.sql @@ -0,0 +1,3 @@ +-- Down migration: No-op (cannot reliably reverse snake_case → camelCase) +-- The data is still correct, just in a different key format. +-- If rollback is needed, the frontend fallback reads both formats. diff --git a/migrations/000135_normalize_property_keys_snake_case.up.sql b/migrations/000135_normalize_property_keys_snake_case.up.sql new file mode 100644 index 00000000..221d2b3c --- /dev/null +++ b/migrations/000135_normalize_property_keys_snake_case.up.sql @@ -0,0 +1,68 @@ +-- Migration: Normalize camelCase JSONB property keys to snake_case +-- This ensures consistency: collectors send snake_case, UI reads snake_case. +-- The backend PromoteKnownProperties now auto-converts camelCase on ingest, +-- but existing data needs a one-time cleanup. + +-- Helper function: convert a single camelCase string to snake_case +CREATE OR REPLACE FUNCTION pg_temp.camel_to_snake(s TEXT) RETURNS TEXT AS $$ +DECLARE + result TEXT := ''; + ch CHAR; + prev_ch CHAR := ''; + i INT; +BEGIN + FOR i IN 1..length(s) LOOP + ch := substr(s, i, 1); + IF ch >= 'A' AND ch <= 'Z' THEN + -- Insert underscore before uppercase if preceded by lowercase + IF prev_ch >= 'a' AND prev_ch <= 'z' THEN + result := result || '_'; + -- Or if preceded by uppercase and followed by lowercase (e.g., "memoryGB" → "memory_gb") + ELSIF prev_ch >= 'A' AND prev_ch <= 'Z' AND i < length(s) AND substr(s, i+1, 1) >= 'a' AND substr(s, i+1, 1) <= 'z' THEN + result := result || '_'; + END IF; + result := result || lower(ch); + ELSE + result := result || ch; + END IF; + prev_ch := ch; + END LOOP; + RETURN result; +END; +$$ LANGUAGE plpgsql IMMUTABLE; + +-- Helper function: recursively normalize all keys in a JSONB object +CREATE OR REPLACE FUNCTION pg_temp.normalize_jsonb_keys(obj JSONB) RETURNS JSONB AS $$ +DECLARE + result JSONB := '{}'; + kv RECORD; + new_key TEXT; +BEGIN + IF obj IS NULL OR jsonb_typeof(obj) != 'object' THEN + RETURN obj; + END IF; + FOR kv IN SELECT * FROM jsonb_each(obj) LOOP + new_key := pg_temp.camel_to_snake(kv.key); + -- If snake_case key already exists, don't overwrite it + IF result ? new_key AND new_key != kv.key THEN + CONTINUE; + END IF; + -- Recursively normalize nested objects (but not arrays) + IF jsonb_typeof(kv.value) = 'object' THEN + result := result || jsonb_build_object(new_key, pg_temp.normalize_jsonb_keys(kv.value)); + ELSE + result := result || jsonb_build_object(new_key, kv.value); + END IF; + END LOOP; + RETURN result; +END; +$$ LANGUAGE plpgsql IMMUTABLE; + +-- Apply normalization to all assets with properties containing camelCase keys +-- Only update rows where at least one key changes (avoid unnecessary writes) +UPDATE assets +SET properties = pg_temp.normalize_jsonb_keys(properties), + updated_at = NOW() +WHERE properties IS NOT NULL + AND properties != '{}'::jsonb + AND properties::text ~ '[a-z][A-Z]'; -- Quick regex check: has camelCase pattern diff --git a/migrations/000136_relationship_suggestions.down.sql b/migrations/000136_relationship_suggestions.down.sql new file mode 100644 index 00000000..03f18f89 --- /dev/null +++ b/migrations/000136_relationship_suggestions.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS relationship_suggestions; diff --git a/migrations/000136_relationship_suggestions.up.sql b/migrations/000136_relationship_suggestions.up.sql new file mode 100644 index 00000000..8ee8adbb --- /dev/null +++ b/migrations/000136_relationship_suggestions.up.sql @@ -0,0 +1,19 @@ +CREATE TABLE IF NOT EXISTS relationship_suggestions ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + tenant_id UUID NOT NULL REFERENCES tenants(id) ON DELETE CASCADE, + source_asset_id UUID NOT NULL REFERENCES assets(id) ON DELETE CASCADE, + target_asset_id UUID NOT NULL REFERENCES assets(id) ON DELETE CASCADE, + relationship_type VARCHAR(50) NOT NULL, + reason TEXT NOT NULL, + confidence DECIMAL(3,2) NOT NULL DEFAULT 1.00, + status VARCHAR(20) NOT NULL DEFAULT 'pending', + reviewed_by UUID REFERENCES users(id), + reviewed_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + CONSTRAINT chk_suggestion_status CHECK (status IN ('pending', 'approved', 'dismissed')), + CONSTRAINT uq_suggestion UNIQUE(tenant_id, source_asset_id, target_asset_id, relationship_type) +); + +CREATE INDEX idx_suggestions_tenant_status ON relationship_suggestions(tenant_id, status); +CREATE INDEX idx_suggestions_source ON relationship_suggestions(source_asset_id); +CREATE INDEX idx_suggestions_target ON relationship_suggestions(target_asset_id); diff --git a/pkg/domain/asset/relationship_repository.go b/pkg/domain/asset/relationship_repository.go index 08bc95f0..7be28316 100644 --- a/pkg/domain/asset/relationship_repository.go +++ b/pkg/domain/asset/relationship_repository.go @@ -41,6 +41,22 @@ type RelationshipRepository interface { // types are actually being used and trim or extend the registry // based on real data instead of guessing. CountByType(ctx context.Context, tenantID shared.ID) (map[RelationshipType]int64, error) + + // ListAllEdges fetches every relationship for the tenant as lightweight + // graph edges. This is used by attack path scoring which needs the full + // graph in-memory. The result is intentionally minimal (just IDs + type) + // so the query is fast even for large tenants. + ListAllEdges(ctx context.Context, tenantID shared.ID) ([]RelationshipEdge, error) +} + +// RelationshipEdge is a lightweight representation of a relationship used +// for in-memory graph traversal (attack path scoring). It carries only the +// fields needed to build the adjacency list — no heavy asset data. +type RelationshipEdge struct { + SourceAssetID string + TargetAssetID string + Type RelationshipType + ImpactWeight int } // RelationshipFilter defines filtering options for relationship queries. diff --git a/pkg/domain/asset/relationship_types_generated.go b/pkg/domain/asset/relationship_types_generated.go index bde553b1..6c016e65 100644 --- a/pkg/domain/asset/relationship_types_generated.go +++ b/pkg/domain/asset/relationship_types_generated.go @@ -151,6 +151,10 @@ var RelationshipTypeRegistry = map[RelationshipType]RelationshipTypeMetadata{ Sources: []string{ "repository" }, Targets: []string{ "container_image" }, }, + { + Sources: []string{ "domain" }, + Targets: []string{ "subdomain" }, + }, }, }, RelTypeExposes: { @@ -179,10 +183,10 @@ var RelationshipTypeRegistry = map[RelationshipType]RelationshipTypeMetadata{ Category: "attack_surface_mapping", Direct: "Resolves To", Inverse: "Resolved By", - Description: "Literal DNS A/AAAA resolution — a domain resolves to an IP record or a load balancer that owns that IP. STRICT semantic: target MUST be the network endpoint, not the server that happens to own the IP. For \"this domain leads to this server / website\" use Exposes. For subdomain → parent domain or CNAME aliases use Cname Of.", + Description: "Literal DNS A/AAAA resolution — a domain or subdomain resolves to an IP record or a load balancer that owns that IP. STRICT semantic: target MUST be the network endpoint, not the server that happens to own the IP. For \"this domain leads to this server / website\" use Exposes. For subdomain → parent domain hierarchy use Contains.", Constraints: []RelationshipConstraint{ { - Sources: []string{ "domain" }, + Sources: []string{ "domain", "subdomain" }, Targets: []string{ "ip_address", "load_balancer" }, }, }, @@ -192,11 +196,11 @@ var RelationshipTypeRegistry = map[RelationshipType]RelationshipTypeMetadata{ Category: "attack_surface_mapping", Direct: "CNAME Of", Inverse: "Has CNAME", - Description: "DNS aliasing — this name is a CNAME for that name. Also used for subdomain → parent domain logical relationships. Distinct from Resolves To which captures the final IP/host record.", + Description: "DNS CNAME aliasing — this name is a CNAME record pointing to another name. Strictly for actual DNS CNAME records, NOT for subdomain hierarchy (use Contains for that). Distinct from Resolves To which captures the final A/AAAA IP record.", Constraints: []RelationshipConstraint{ { - Sources: []string{ "domain" }, - Targets: []string{ "domain" }, + Sources: []string{ "domain", "subdomain" }, + Targets: []string{ "domain", "subdomain" }, }, }, }, diff --git a/pkg/domain/asset/repository.go b/pkg/domain/asset/repository.go index dbda29c9..26d7c826 100644 --- a/pkg/domain/asset/repository.go +++ b/pkg/domain/asset/repository.go @@ -79,7 +79,7 @@ type Repository interface { // ListDistinctTags returns distinct tags across all assets for a tenant. // Supports prefix filtering for autocomplete and a limit for result size. - ListDistinctTags(ctx context.Context, tenantID shared.ID, prefix string, limit int) ([]string, error) + ListDistinctTags(ctx context.Context, tenantID shared.ID, prefix string, types []string, limit int) ([]string, error) // GetAssetTypeBreakdown returns total and exposed counts grouped by asset_type in a single query. // This replaces the N+1 pattern of calling Count() per type. @@ -98,10 +98,30 @@ type Repository interface { // GetAggregateStats computes all asset statistics using SQL aggregation. // Filters: types (asset_type ANY), tags (overlap, matches List semantics). - GetAggregateStats(ctx context.Context, tenantID shared.ID, types []string, tags []string, subType string) (*AggregateStats, error) + GetAggregateStats(ctx context.Context, tenantID shared.ID, types []string, tags []string, subType string, countByFields ...string) (*AggregateStats, error) // GetPropertyFacets returns distinct JSONB property keys and their top values for faceted filtering. GetPropertyFacets(ctx context.Context, tenantID shared.ID, types []string, subType string) ([]PropertyFacet, error) + + // ListAllNodes fetches every asset for the tenant as lightweight graph nodes. + // Used exclusively by attack path scoring which needs the full set of assets + // in-memory. Columns are minimal (id, name, type, exposure, criticality, + // risk_score, is_crown_jewel, finding_count) to keep the query fast. + ListAllNodes(ctx context.Context, tenantID shared.ID) ([]AssetNode, error) +} + +// AssetNode is a lightweight representation of an asset used for in-memory +// graph traversal (attack path scoring). It carries only the fields needed +// to build adjacency lists and compute path risk scores. +type AssetNode struct { + ID string + Name string + AssetType string + Exposure string + Criticality string + RiskScore int + IsCrownJewel bool + FindingCount int } // PropertyFacet represents a property key with its distinct values for filtering UI. @@ -125,6 +145,8 @@ type AggregateStats struct { FindingsTotal int HighRiskCount int RiskScoreAvg float64 + // MetadataCounts: JSONB property value counts. Key=field, Value=map[value]count. + MetadataCounts map[string]map[string]int } // AssetTypeStats holds per-type aggregate counts. diff --git a/pkg/domain/relationship/suggestion.go b/pkg/domain/relationship/suggestion.go new file mode 100644 index 00000000..ffaab35e --- /dev/null +++ b/pkg/domain/relationship/suggestion.go @@ -0,0 +1,172 @@ +// Package relationship provides domain entities for relationship suggestions. +package relationship + +import ( + "context" + "fmt" + "time" + + "github.com/openctemio/api/pkg/domain/shared" + "github.com/openctemio/api/pkg/pagination" +) + +// Suggestion status constants. +const ( + SuggestionPending = "pending" + SuggestionApproved = "approved" + SuggestionDismissed = "dismissed" +) + +// Suggestion represents a suggested relationship between two assets. +type Suggestion struct { + id shared.ID + tenantID shared.ID + sourceAssetID shared.ID + targetAssetID shared.ID + relationshipType string + reason string + confidence float64 + status string + reviewedBy *shared.ID + reviewedAt *time.Time + createdAt time.Time + // Enrichment fields (populated by JOINs, not stored) + sourceAssetName string + sourceAssetType string + targetAssetName string + targetAssetType string +} + +// NewSuggestion creates a new Suggestion with validation. +func NewSuggestion( + tenantID, sourceAssetID, targetAssetID shared.ID, + relType, reason string, + confidence float64, +) (*Suggestion, error) { + if tenantID.IsZero() { + return nil, fmt.Errorf("%w: tenant ID is required", shared.ErrValidation) + } + if sourceAssetID.IsZero() { + return nil, fmt.Errorf("%w: source asset ID is required", shared.ErrValidation) + } + if targetAssetID.IsZero() { + return nil, fmt.Errorf("%w: target asset ID is required", shared.ErrValidation) + } + if relType == "" { + return nil, fmt.Errorf("%w: relationship type is required", shared.ErrValidation) + } + if reason == "" { + return nil, fmt.Errorf("%w: reason is required", shared.ErrValidation) + } + if confidence < 0 || confidence > 1 { + return nil, fmt.Errorf("%w: confidence must be between 0 and 1", shared.ErrValidation) + } + + now := time.Now().UTC() + return &Suggestion{ + id: shared.NewID(), + tenantID: tenantID, + sourceAssetID: sourceAssetID, + targetAssetID: targetAssetID, + relationshipType: relType, + reason: reason, + confidence: confidence, + status: SuggestionPending, + createdAt: now, + }, nil +} + +// ReconstituteSuggestion rebuilds a Suggestion from persistence. +func ReconstituteSuggestion( + id, tenantID, sourceAssetID, targetAssetID shared.ID, + relType, reason string, + confidence float64, + status string, + reviewedBy *shared.ID, + reviewedAt *time.Time, + createdAt time.Time, +) *Suggestion { + return &Suggestion{ + id: id, + tenantID: tenantID, + sourceAssetID: sourceAssetID, + targetAssetID: targetAssetID, + relationshipType: relType, + reason: reason, + confidence: confidence, + status: status, + reviewedBy: reviewedBy, + reviewedAt: reviewedAt, + createdAt: createdAt, + } +} + +// Approve marks the suggestion as approved. +func (s *Suggestion) Approve(reviewerID shared.ID) { + s.status = SuggestionApproved + s.reviewedBy = &reviewerID + now := time.Now().UTC() + s.reviewedAt = &now +} + +// UpdateRelationshipType changes the suggested relationship type. +func (s *Suggestion) UpdateRelationshipType(relType string) error { + if relType == "" { + return fmt.Errorf("%w: relationship type is required", shared.ErrValidation) + } + s.relationshipType = relType + return nil +} + +// Dismiss marks the suggestion as dismissed. +func (s *Suggestion) Dismiss(reviewerID shared.ID) { + s.status = SuggestionDismissed + s.reviewedBy = &reviewerID + now := time.Now().UTC() + s.reviewedAt = &now +} + +// Accessors. + +func (s *Suggestion) ID() shared.ID { return s.id } +func (s *Suggestion) TenantID() shared.ID { return s.tenantID } +func (s *Suggestion) SourceAssetID() shared.ID { return s.sourceAssetID } +func (s *Suggestion) TargetAssetID() shared.ID { return s.targetAssetID } +func (s *Suggestion) RelationshipType() string { return s.relationshipType } +func (s *Suggestion) Reason() string { return s.reason } +func (s *Suggestion) Confidence() float64 { return s.confidence } +func (s *Suggestion) Status() string { return s.status } +func (s *Suggestion) ReviewedBy() *shared.ID { return s.reviewedBy } +func (s *Suggestion) ReviewedAt() *time.Time { return s.reviewedAt } +func (s *Suggestion) CreatedAt() time.Time { return s.createdAt } +func (s *Suggestion) SourceAssetName() string { return s.sourceAssetName } +func (s *Suggestion) SourceAssetType() string { return s.sourceAssetType } +func (s *Suggestion) TargetAssetName() string { return s.targetAssetName } +func (s *Suggestion) TargetAssetType() string { return s.targetAssetType } + +// SetAssetInfo sets enrichment fields (called by repository after JOIN). +func (s *Suggestion) SetAssetInfo(srcName, srcType, tgtName, tgtType string) { + s.sourceAssetName = srcName + s.sourceAssetType = srcType + s.targetAssetName = tgtName + s.targetAssetType = tgtType +} + +// SuggestionRepository defines the persistence interface for suggestions. +type SuggestionRepository interface { + Create(ctx context.Context, s *Suggestion) error + CreateBatch(ctx context.Context, suggestions []*Suggestion) (int, error) + GetByID(ctx context.Context, tenantID, id shared.ID) (*Suggestion, error) + ListPending(ctx context.Context, tenantID shared.ID, search string, page pagination.Pagination) (pagination.Result[*Suggestion], error) + UpdateStatus(ctx context.Context, s *Suggestion) error + CountPending(ctx context.Context, tenantID shared.ID) (int64, error) + DeleteByAssetID(ctx context.Context, tenantID, assetID shared.ID) error + DeletePending(ctx context.Context, tenantID shared.ID) error + ApproveAll(ctx context.Context, tenantID, reviewerID shared.ID) ([]*Suggestion, error) + UpdateRelationshipType(ctx context.Context, tenantID, id shared.ID, relType string) error +} + +// Errors. +var ( + ErrSuggestionNotFound = fmt.Errorf("%w: suggestion not found", shared.ErrNotFound) +) diff --git a/pkg/domain/remediation/campaign.go b/pkg/domain/remediation/campaign.go index 9f3068c7..1aa7b701 100644 --- a/pkg/domain/remediation/campaign.go +++ b/pkg/domain/remediation/campaign.go @@ -158,6 +158,30 @@ func (c *Campaign) SetTimeline(startDate, dueDate *time.Time) { c.updatedAt = time.Now() } +// SetName sets campaign name. +func (c *Campaign) SetName(name string) { + c.name = name + c.updatedAt = time.Now() +} + +// SetDescription sets campaign description. +func (c *Campaign) SetDescription(desc string) { + c.description = desc + c.updatedAt = time.Now() +} + +// SetPriority sets campaign priority. +func (c *Campaign) SetPriority(p CampaignPriority) { + c.priority = p + c.updatedAt = time.Now() +} + +// SetDueDate sets campaign due date. +func (c *Campaign) SetDueDate(d *time.Time) { + c.dueDate = d + c.updatedAt = time.Now() +} + // SetTags sets campaign tags. func (c *Campaign) SetTags(tags []string) { c.tags = tags diff --git a/pkg/domain/simulation/repository.go b/pkg/domain/simulation/repository.go index 94bc1d3f..a4181d6c 100644 --- a/pkg/domain/simulation/repository.go +++ b/pkg/domain/simulation/repository.go @@ -55,6 +55,26 @@ type ControlTestRepository interface { Delete(ctx context.Context, tenantID, id shared.ID) error List(ctx context.Context, filter ControlTestFilter, page pagination.Pagination) (pagination.Result[*ControlTest], error) GetStatsByFramework(ctx context.Context, tenantID shared.ID) ([]FrameworkStats, error) + + // ListOverdue returns control tests that have not been tested for at least staleDays days, + // or have never been tested (last_tested_at IS NULL), across all tenants. + // Used by the ControlTestSchedulerController to surface stale coverage. + // Returns up to limit results per call to avoid unbounded queries. + ListOverdue(ctx context.Context, staleDays int, limit int) ([]*OverdueControlTest, error) + + // MarkOverdue sets the status of a control test to 'untested' when it has gone + // past its review interval, signalling that detection coverage has lapsed. + // Security: tenantID enforces tenant isolation. + MarkOverdue(ctx context.Context, tenantID, id shared.ID) error +} + +// OverdueControlTest is returned by ListOverdue — includes the tenant ID for routing. +type OverdueControlTest struct { + TenantID shared.ID + ControlTestID shared.ID + Name string + Framework string + DaysSinceTested int // 0 if never tested (last_tested_at IS NULL) } // FrameworkStats holds aggregated control test statistics per framework. diff --git a/pkg/domain/vulnerability/repository.go b/pkg/domain/vulnerability/repository.go index e2509909..51515c09 100644 --- a/pkg/domain/vulnerability/repository.go +++ b/pkg/domain/vulnerability/repository.go @@ -353,6 +353,17 @@ type FindingRepository interface { // ListByStatusAndAssets returns findings with a specific status on specific assets. // Used by auto-verify: find fix_applied findings on assets that were just scanned. ListByStatusAndAssets(ctx context.Context, tenantID shared.ID, status FindingStatus, assetIDs []shared.ID) ([]*Finding, error) + + // UpdateWorkItemURIs updates the work_item_uris field for a finding. + // This is a targeted patch — does not modify other fields. + // Used by the Jira ticketing integration to persist ticket references. + // Security: tenantID enforces tenant isolation (IDOR prevention). + UpdateWorkItemURIs(ctx context.Context, tenantID, id shared.ID, uris []string) error + + // GetByWorkItemURI retrieves a finding that has a specific work item URI. + // Used by the Jira webhook receiver to map inbound status changes back to findings. + // Security: tenantID enforces tenant isolation. + GetByWorkItemURI(ctx context.Context, tenantID shared.ID, uri string) (*Finding, error) } // FindingGroup represents a group of findings aggregated by a dimension. diff --git a/tests/unit/asset_handler_test.go b/tests/unit/asset_handler_test.go index 598fe6f6..b19043e2 100644 --- a/tests/unit/asset_handler_test.go +++ b/tests/unit/asset_handler_test.go @@ -174,7 +174,7 @@ func (m *HandlerMockRepository) UpdateFindingCounts(ctx context.Context, tenantI return nil } -func (m *HandlerMockRepository) ListDistinctTags(ctx context.Context, tenantID shared.ID, prefix string, limit int) ([]string, error) { +func (m *HandlerMockRepository) ListDistinctTags(ctx context.Context, tenantID shared.ID, prefix string, types []string, limit int) ([]string, error) { return []string{}, nil } @@ -194,7 +194,7 @@ func (m *HandlerMockRepository) BulkUpdateStatus(_ context.Context, _ shared.ID, return 0, nil } -func (m *HandlerMockRepository) GetAggregateStats(_ context.Context, _ shared.ID, _ []string, _ []string, _ string) (*asset.AggregateStats, error) { +func (m *HandlerMockRepository) GetAggregateStats(_ context.Context, _ shared.ID, _ []string, _ []string, _ string, _ ...string) (*asset.AggregateStats, error) { return &asset.AggregateStats{ ByType: make(map[string]int), ByStatus: make(map[string]int), @@ -208,6 +208,10 @@ func (m *HandlerMockRepository) GetPropertyFacets(_ context.Context, _ shared.ID return nil, nil } +func (m *HandlerMockRepository) ListAllNodes(_ context.Context, _ shared.ID) ([]asset.AssetNode, error) { + return nil, nil +} + func newTestHandler() *handler.AssetHandler { repo := NewHandlerMockRepository() log := logger.NewDevelopment() diff --git a/tests/unit/asset_relationship_service_test.go b/tests/unit/asset_relationship_service_test.go index b002dcfd..2c100061 100644 --- a/tests/unit/asset_relationship_service_test.go +++ b/tests/unit/asset_relationship_service_test.go @@ -196,6 +196,10 @@ func (m *MockRelationshipRepository) CountByType(_ context.Context, _ shared.ID) return out, nil } +func (m *MockRelationshipRepository) ListAllEdges(_ context.Context, _ shared.ID) ([]asset.RelationshipEdge, error) { + return nil, nil +} + // AddRelationshipWithAssets adds a pre-built RelationshipWithAssets to the mock store. func (m *MockRelationshipRepository) AddRelationshipWithAssets(rwa *asset.RelationshipWithAssets) { m.relationships[rwa.Relationship.ID().String()] = rwa.Relationship diff --git a/tests/unit/asset_service_test.go b/tests/unit/asset_service_test.go index ceb73c34..8c2b27b1 100644 --- a/tests/unit/asset_service_test.go +++ b/tests/unit/asset_service_test.go @@ -221,7 +221,7 @@ func (m *MockAssetRepository) UpdateFindingCounts(_ context.Context, _ shared.ID return nil } -func (m *MockAssetRepository) ListDistinctTags(_ context.Context, _ shared.ID, _ string, _ int) ([]string, error) { +func (m *MockAssetRepository) ListDistinctTags(_ context.Context, _ shared.ID, _ string, _ []string, _ int) ([]string, error) { return []string{}, nil } @@ -256,7 +256,7 @@ func (m *MockAssetRepository) BulkUpdateStatus(_ context.Context, _ shared.ID, i return updated, nil } -func (m *MockAssetRepository) GetAggregateStats(_ context.Context, _ shared.ID, _ []string, _ []string, _ string) (*asset.AggregateStats, error) { +func (m *MockAssetRepository) GetAggregateStats(_ context.Context, _ shared.ID, _ []string, _ []string, _ string, _ ...string) (*asset.AggregateStats, error) { return &asset.AggregateStats{ ByType: make(map[string]int), ByStatus: make(map[string]int), @@ -270,6 +270,10 @@ func (m *MockAssetRepository) GetPropertyFacets(_ context.Context, _ shared.ID, return nil, nil } +func (m *MockAssetRepository) ListAllNodes(_ context.Context, _ shared.ID) ([]asset.AssetNode, error) { + return nil, nil +} + // ============================================================================= // Mock Repository Extension Repository // ============================================================================= @@ -1948,7 +1952,7 @@ func TestAssetService_ListTags_Success(t *testing.T) { svc, _ := newTestService() tenantID := serviceTenantID.String() - tags, err := svc.ListTags(context.Background(), tenantID, "", 50) + tags, err := svc.ListTags(context.Background(), tenantID, "", nil, 50) if err != nil { t.Fatalf("expected no error, got %v", err) } @@ -1960,7 +1964,7 @@ func TestAssetService_ListTags_Success(t *testing.T) { func TestAssetService_ListTags_InvalidTenantID(t *testing.T) { svc, _ := newTestService() - _, err := svc.ListTags(context.Background(), "bad-uuid", "", 50) + _, err := svc.ListTags(context.Background(), "bad-uuid", "", nil, 50) if err == nil { t.Fatal("expected error for invalid tenant ID") } @@ -1974,17 +1978,17 @@ func TestAssetService_ListTags_DefaultLimit(t *testing.T) { tenantID := serviceTenantID.String() // Limit <= 0 should default to 50, limit > 100 should default to 50 - _, err := svc.ListTags(context.Background(), tenantID, "", 0) + _, err := svc.ListTags(context.Background(), tenantID, "", nil, 0) if err != nil { t.Fatalf("expected no error with zero limit, got %v", err) } - _, err = svc.ListTags(context.Background(), tenantID, "", -1) + _, err = svc.ListTags(context.Background(), tenantID, "", nil, -1) if err != nil { t.Fatalf("expected no error with negative limit, got %v", err) } - _, err = svc.ListTags(context.Background(), tenantID, "", 200) + _, err = svc.ListTags(context.Background(), tenantID, "", nil, 200) if err != nil { t.Fatalf("expected no error with over-limit, got %v", err) } diff --git a/tests/unit/attack_surface_service_test.go b/tests/unit/attack_surface_service_test.go index 83ca22ef..2431f413 100644 --- a/tests/unit/attack_surface_service_test.go +++ b/tests/unit/attack_surface_service_test.go @@ -160,7 +160,7 @@ func (m *mockAttackSurfaceRepo) UpdateFindingCounts(_ context.Context, _ shared. return nil } -func (m *mockAttackSurfaceRepo) ListDistinctTags(_ context.Context, _ shared.ID, _ string, _ int) ([]string, error) { +func (m *mockAttackSurfaceRepo) ListDistinctTags(_ context.Context, _ shared.ID, _ string, _ []string, _ int) ([]string, error) { return []string{}, nil } @@ -172,7 +172,7 @@ func (m *mockAttackSurfaceRepo) BulkUpdateStatus(_ context.Context, _ shared.ID, return 0, nil } -func (m *mockAttackSurfaceRepo) GetAggregateStats(_ context.Context, _ shared.ID, _ []string, _ []string, _ string) (*asset.AggregateStats, error) { +func (m *mockAttackSurfaceRepo) GetAggregateStats(_ context.Context, _ shared.ID, _ []string, _ []string, _ string, _ ...string) (*asset.AggregateStats, error) { return &asset.AggregateStats{ ByType: make(map[string]int), ByStatus: make(map[string]int), @@ -186,13 +186,64 @@ func (m *mockAttackSurfaceRepo) GetPropertyFacets(_ context.Context, _ shared.ID return nil, nil } +func (m *mockAttackSurfaceRepo) ListAllNodes(_ context.Context, _ shared.ID) ([]asset.AssetNode, error) { + return nil, nil +} + +// ============================================================================= +// Mock Relationship Repository for Attack Surface Service +// ============================================================================= + +type mockAttackSurfaceRelRepo struct{} + +func (m *mockAttackSurfaceRelRepo) Create(_ context.Context, _ *asset.Relationship) error { + return nil +} + +func (m *mockAttackSurfaceRelRepo) GetByID(_ context.Context, _, _ shared.ID) (*asset.RelationshipWithAssets, error) { + return nil, shared.ErrNotFound +} + +func (m *mockAttackSurfaceRelRepo) Update(_ context.Context, _ *asset.Relationship) error { + return nil +} + +func (m *mockAttackSurfaceRelRepo) Delete(_ context.Context, _, _ shared.ID) error { + return nil +} + +func (m *mockAttackSurfaceRelRepo) ListByAsset(_ context.Context, _, _ shared.ID, _ asset.RelationshipFilter) ([]*asset.RelationshipWithAssets, int64, error) { + return nil, 0, nil +} + +func (m *mockAttackSurfaceRelRepo) Exists(_ context.Context, _, _, _ shared.ID, _ asset.RelationshipType) (bool, error) { + return false, nil +} + +func (m *mockAttackSurfaceRelRepo) CountByAsset(_ context.Context, _, _ shared.ID) (int64, error) { + return 0, nil +} + +func (m *mockAttackSurfaceRelRepo) CreateBatchIgnoreConflicts(_ context.Context, _ []*asset.Relationship) (int, error) { + return 0, nil +} + +func (m *mockAttackSurfaceRelRepo) CountByType(_ context.Context, _ shared.ID) (map[asset.RelationshipType]int64, error) { + return nil, nil +} + +func (m *mockAttackSurfaceRelRepo) ListAllEdges(_ context.Context, _ shared.ID) ([]asset.RelationshipEdge, error) { + return nil, nil +} + // ============================================================================= // Helper Functions // ============================================================================= func newTestAttackSurfaceService(repo *mockAttackSurfaceRepo) *app.AttackSurfaceService { log := logger.NewNop() - return app.NewAttackSurfaceService(repo, log) + relRepo := &mockAttackSurfaceRelRepo{} + return app.NewAttackSurfaceService(repo, relRepo, log) } // makeAttackSurfaceAsset creates a test asset using Reconstitute with the given parameters. diff --git a/tests/unit/branch_lifecycle_test.go b/tests/unit/branch_lifecycle_test.go index 0458368c..0595b3ba 100644 --- a/tests/unit/branch_lifecycle_test.go +++ b/tests/unit/branch_lifecycle_test.go @@ -393,3 +393,11 @@ func (m *MockFindingRepoForLifecycle) FindRelatedCVEs(_ context.Context, _ share func (m *MockFindingRepoForLifecycle) ListByStatusAndAssets(_ context.Context, _ shared.ID, _ vulnerability.FindingStatus, _ []shared.ID) ([]*vulnerability.Finding, error) { return nil, nil } + +func (m *MockFindingRepoForLifecycle) GetByWorkItemURI(_ context.Context, _ shared.ID, _ string) (*vulnerability.Finding, error) { + return nil, nil +} + +func (m *MockFindingRepoForLifecycle) UpdateWorkItemURIs(_ context.Context, _, _ shared.ID, _ []string) error { + return nil +} diff --git a/tests/unit/data_scope_test.go b/tests/unit/data_scope_test.go index e97b1d9c..41cb1c3f 100644 --- a/tests/unit/data_scope_test.go +++ b/tests/unit/data_scope_test.go @@ -929,3 +929,11 @@ func (m *mockFindingRepoForScope) FindRelatedCVEs(_ context.Context, _ shared.ID func (m *mockFindingRepoForScope) ListByStatusAndAssets(_ context.Context, _ shared.ID, _ vulnerability.FindingStatus, _ []shared.ID) ([]*vulnerability.Finding, error) { return nil, nil } + +func (m *mockFindingRepoForScope) GetByWorkItemURI(_ context.Context, _ shared.ID, _ string) (*vulnerability.Finding, error) { + return nil, nil +} + +func (m *mockFindingRepoForScope) UpdateWorkItemURIs(_ context.Context, _, _ shared.ID, _ []string) error { + return nil +} diff --git a/tests/unit/finding_activity_service_test.go b/tests/unit/finding_activity_service_test.go index 445a50d8..2f6cce16 100644 --- a/tests/unit/finding_activity_service_test.go +++ b/tests/unit/finding_activity_service_test.go @@ -1616,3 +1616,11 @@ func (m *stubFindingRepo) FindRelatedCVEs(_ context.Context, _ shared.ID, _ stri func (m *stubFindingRepo) ListByStatusAndAssets(_ context.Context, _ shared.ID, _ vulnerability.FindingStatus, _ []shared.ID) ([]*vulnerability.Finding, error) { return nil, nil } + +func (m *stubFindingRepo) GetByWorkItemURI(_ context.Context, _ shared.ID, _ string) (*vulnerability.Finding, error) { + return nil, nil +} + +func (m *stubFindingRepo) UpdateWorkItemURIs(_ context.Context, _, _ shared.ID, _ []string) error { + return nil +} diff --git a/tests/unit/finding_approval_service_test.go b/tests/unit/finding_approval_service_test.go index 91e41d38..ecc86108 100644 --- a/tests/unit/finding_approval_service_test.go +++ b/tests/unit/finding_approval_service_test.go @@ -1152,3 +1152,9 @@ func (m *mockFindingRepository) FindRelatedCVEs(_ context.Context, _ shared.ID, func (m *mockFindingRepository) ListByStatusAndAssets(_ context.Context, _ shared.ID, _ vulnerability.FindingStatus, _ []shared.ID) ([]*vulnerability.Finding, error) { return nil, nil } +func (m *mockFindingRepository) GetByWorkItemURI(_ context.Context, _ shared.ID, _ string) (*vulnerability.Finding, error) { + return nil, nil +} +func (m *mockFindingRepository) UpdateWorkItemURIs(_ context.Context, _, _ shared.ID, _ []string) error { + return nil +} diff --git a/tests/unit/pentest_service_test.go b/tests/unit/pentest_service_test.go index 408355a2..ac96b191 100644 --- a/tests/unit/pentest_service_test.go +++ b/tests/unit/pentest_service_test.go @@ -1487,3 +1487,11 @@ func (m *mockUnifiedFindingRepo) FindRelatedCVEs(_ context.Context, _ shared.ID, func (m *mockUnifiedFindingRepo) ListByStatusAndAssets(_ context.Context, _ shared.ID, _ vulnerability.FindingStatus, _ []shared.ID) ([]*vulnerability.Finding, error) { return nil, nil } + +func (m *mockUnifiedFindingRepo) GetByWorkItemURI(_ context.Context, _ shared.ID, _ string) (*vulnerability.Finding, error) { + return nil, nil +} + +func (m *mockUnifiedFindingRepo) UpdateWorkItemURIs(_ context.Context, _, _ shared.ID, _ []string) error { + return nil +} diff --git a/tests/unit/promote_properties_test.go b/tests/unit/promote_properties_test.go index adc9a8ef..0e3e81d1 100644 --- a/tests/unit/promote_properties_test.go +++ b/tests/unit/promote_properties_test.go @@ -149,6 +149,63 @@ func TestPromoteKnownProperties_RemoveColumnNames(t *testing.T) { assert.Equal(t, "should-stay", result.Properties["vendor"]) } +func TestPromoteKnownProperties_CamelToSnakeNormalization(t *testing.T) { + input := app.CreateAssetInput{ + Name: "srv-01", + Type: "host", + Properties: map[string]any{ + "cpuCores": 8, + "memoryGB": 32, + "osVersion": "22.04", + "isVirtual": true, + "openPorts": []any{"22", "80"}, + "apiType": "REST", + "baseUrl": "https://example.com", + "vendor": "Dell", // already snake_case — stays + "record_type": "A", // already snake_case — stays + }, + } + + result := app.PromoteKnownProperties(input) + + // camelCase keys converted to snake_case + assert.Equal(t, 8, result.Properties["cpu_cores"]) + assert.Equal(t, 32, result.Properties["memory_gb"]) + assert.Equal(t, "22.04", result.Properties["os_version"]) + assert.Equal(t, true, result.Properties["is_virtual"]) + assert.Equal(t, []any{"22", "80"}, result.Properties["open_ports"]) + assert.Equal(t, "REST", result.Properties["api_type"]) + assert.Equal(t, "https://example.com", result.Properties["base_url"]) + + // Already snake_case or single-word keys unchanged + assert.Equal(t, "Dell", result.Properties["vendor"]) + assert.Equal(t, "A", result.Properties["record_type"]) + + // Old camelCase keys removed + assert.Nil(t, result.Properties["cpuCores"]) + assert.Nil(t, result.Properties["memoryGB"]) + assert.Nil(t, result.Properties["osVersion"]) + assert.Nil(t, result.Properties["isVirtual"]) + assert.Nil(t, result.Properties["openPorts"]) +} + +func TestPromoteKnownProperties_CamelSnakeDuplicate(t *testing.T) { + // When both camelCase and snake_case exist, prefer snake_case + input := app.CreateAssetInput{ + Name: "srv-01", + Type: "host", + Properties: map[string]any{ + "cpu_cores": 16, // snake_case (should win) + "cpuCores": 8, // camelCase (should be dropped) + }, + } + + result := app.PromoteKnownProperties(input) + + assert.Equal(t, 16, result.Properties["cpu_cores"]) + assert.Nil(t, result.Properties["cpuCores"]) +} + func TestPromoteKnownProperties_EmptyProperties(t *testing.T) { input := app.CreateAssetInput{ Name: "srv-01", diff --git a/tests/unit/scope_service_test.go b/tests/unit/scope_service_test.go index 3c149009..1f502e7a 100644 --- a/tests/unit/scope_service_test.go +++ b/tests/unit/scope_service_test.go @@ -314,7 +314,7 @@ func (m *mockAssetRepo) UpdateFindingCounts(_ context.Context, _ shared.ID, _ [] return nil } -func (m *mockAssetRepo) ListDistinctTags(_ context.Context, _ shared.ID, _ string, _ int) ([]string, error) { +func (m *mockAssetRepo) ListDistinctTags(_ context.Context, _ shared.ID, _ string, _ []string, _ int) ([]string, error) { return []string{}, nil } @@ -334,7 +334,7 @@ func (m *mockAssetRepo) BulkUpdateStatus(_ context.Context, _ shared.ID, _ []sha return 0, nil } -func (m *mockAssetRepo) GetAggregateStats(_ context.Context, _ shared.ID, _ []string, _ []string, _ string) (*asset.AggregateStats, error) { +func (m *mockAssetRepo) GetAggregateStats(_ context.Context, _ shared.ID, _ []string, _ []string, _ string, _ ...string) (*asset.AggregateStats, error) { return &asset.AggregateStats{ ByType: make(map[string]int), ByStatus: make(map[string]int), @@ -348,6 +348,10 @@ func (m *mockAssetRepo) GetPropertyFacets(_ context.Context, _ shared.ID, _ []st return nil, nil } +func (m *mockAssetRepo) ListAllNodes(_ context.Context, _ shared.ID) ([]asset.AssetNode, error) { + return nil, nil +} + // ============================================================================= // Helpers // ============================================================================= diff --git a/tests/unit/vulnerability_service_test.go b/tests/unit/vulnerability_service_test.go index 790065f8..a01accf3 100644 --- a/tests/unit/vulnerability_service_test.go +++ b/tests/unit/vulnerability_service_test.go @@ -3572,3 +3572,11 @@ func (m *mockFindingRepo) FindRelatedCVEs(_ context.Context, _ shared.ID, _ stri func (m *mockFindingRepo) ListByStatusAndAssets(_ context.Context, _ shared.ID, _ vulnerability.FindingStatus, _ []shared.ID) ([]*vulnerability.Finding, error) { return nil, nil } + +func (m *mockFindingRepo) GetByWorkItemURI(_ context.Context, _ shared.ID, _ string) (*vulnerability.Finding, error) { + return nil, nil +} + +func (m *mockFindingRepo) UpdateWorkItemURIs(_ context.Context, _, _ shared.ID, _ []string) error { + return nil +} diff --git a/tests/unit/workflow_action_handlers_test.go b/tests/unit/workflow_action_handlers_test.go index 7c24e418..4c8843f9 100644 --- a/tests/unit/workflow_action_handlers_test.go +++ b/tests/unit/workflow_action_handlers_test.go @@ -1331,3 +1331,11 @@ func (m *wfActionMockFindingRepo) FindRelatedCVEs(_ context.Context, _ shared.ID func (m *wfActionMockFindingRepo) ListByStatusAndAssets(_ context.Context, _ shared.ID, _ vulnerability.FindingStatus, _ []shared.ID) ([]*vulnerability.Finding, error) { return nil, nil } + +func (m *wfActionMockFindingRepo) GetByWorkItemURI(_ context.Context, _ shared.ID, _ string) (*vulnerability.Finding, error) { + return nil, nil +} + +func (m *wfActionMockFindingRepo) UpdateWorkItemURIs(_ context.Context, _, _ shared.ID, _ []string) error { + return nil +}