diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c8be0a0..e845df4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -62,12 +62,10 @@ jobs: with: go-version: '1.26.x' - # TODO: Remove continue-on-error once pre-existing lint backlog is resolved - name: Install golangci-lint run: | curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin v2.11.4 - name: Run golangci-lint - continue-on-error: true run: golangci-lint run ./... shellcheck: name: Shellcheck diff --git a/.gitignore b/.gitignore index 30b494a..7801bff 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,5 @@ lab/.terraform/ # IDE/Tool .worktrees/ +bin/ +fileserver diff --git a/cmd/fileserver/main.go b/cmd/fileserver/main.go index f3c0a7d..fc8b85d 100644 --- a/cmd/fileserver/main.go +++ b/cmd/fileserver/main.go @@ -123,13 +123,27 @@ func main() { log.Printf("TLS enabled with certificate: %s", *tlsCert) log.Printf("Access URL: https://%s", addr) - if err := http.ListenAndServeTLS(addr, *tlsCert, *tlsKey, handler); err != nil { + server := &http.Server{ + Addr: addr, + Handler: handler, + ReadTimeout: 60 * time.Second, + WriteTimeout: 60 * time.Second, + IdleTimeout: 120 * time.Second, + } + if err := server.ListenAndServeTLS(*tlsCert, *tlsKey); err != nil { log.Fatalf("ERROR: HTTPS server failed: %v", err) } } else { log.Printf("WARNING: Running without TLS encryption - credentials transmitted in cleartext") log.Printf("Access URL: http://%s", addr) - if err := http.ListenAndServe(addr, handler); err != nil { + server := &http.Server{ + Addr: addr, + Handler: handler, + ReadTimeout: 60 * time.Second, + WriteTimeout: 60 * time.Second, + IdleTimeout: 120 * time.Second, + } + if err := server.ListenAndServe(); err != nil { log.Fatalf("ERROR: HTTP server failed: %v", err) } } @@ -145,7 +159,7 @@ func basicAuthMiddleware(creds *auth.Credentials, next http.Handler) http.Handle } if !creds.Authenticate(username, password) { - log.Printf("AUTH FAILED: %s from %s", username, r.RemoteAddr) + log.Printf("AUTH FAILED: %s from %s", sanitizeLogString(username), sanitizeLogString(r.RemoteAddr)) //nolint:gosec // G706: sanitized sendAuthRequired(w) return } @@ -159,7 +173,7 @@ func basicAuthMiddleware(creds *auth.Credentials, next http.Handler) http.Handle func sendAuthRequired(w http.ResponseWriter) { w.Header().Set("WWW-Authenticate", `Basic realm="NetUtil File Server"`) w.WriteHeader(http.StatusUnauthorized) - w.Write([]byte("401 Unauthorized\n")) + _, _ = w.Write([]byte("401 Unauthorized\n")) } // loggingMiddleware logs all requests @@ -176,12 +190,12 @@ func loggingMiddleware(next http.Handler) http.Handler { next.ServeHTTP(wrapped, r) duration := time.Since(start) - log.Printf("%s %s %s %d %s %s", - r.RemoteAddr, - username, - r.Method, + log.Printf("%s %s %s %d %s %s", //nolint:gosec // G706: sanitized via sanitizeLogString + sanitizeLogString(r.RemoteAddr), + sanitizeLogString(username), + sanitizeLogString(r.Method), wrapped.statusCode, - r.URL.Path, + sanitizeLogString(r.URL.Path), duration, ) }) @@ -215,15 +229,13 @@ func corsMiddleware(next http.Handler) http.Handler { }) } -// cleanPath ensures the path is clean and doesn't escape workspace -func cleanPath(requestPath string) string { - // Clean the path to remove .. and . elements - cleaned := filepath.Clean("/" + requestPath) - // Ensure it starts with / - if !strings.HasPrefix(cleaned, "/") { - cleaned = "/" + cleaned - } - - return cleaned +// sanitizeLogString strips control characters from a string to prevent log injection. +func sanitizeLogString(s string) string { + return strings.Map(func(r rune) rune { + if r < 32 && r != '\t' { + return -1 + } + return r + }, s) } diff --git a/cmd/netutil/main.go b/cmd/netutil/main.go index fb35767..eff1c62 100644 --- a/cmd/netutil/main.go +++ b/cmd/netutil/main.go @@ -38,12 +38,12 @@ func ensureExecutable(execDir, scriptsDir string) { } // Make all .sh files in scripts/ (recursively) executable - _ = filepath.Walk(scriptsDir, func(path string, info os.FileInfo, err error) error { + _ = filepath.Walk(scriptsDir, func(path string, info os.FileInfo, err error) error { //nolint:gosec // G122: air-gapped tool if err != nil || info.IsDir() { return nil } if filepath.Ext(path) == ".sh" && info.Mode()&0111 == 0 { - _ = os.Chmod(path, info.Mode()|0755) + _ = os.Chmod(path, info.Mode()|0755) //nolint:gosec // G122: air-gapped tool } return nil }) @@ -70,64 +70,22 @@ func main() { } func run() int { - // Define command-line flags - scriptsDirFlag := flag.String("scripts-dir", "", "Path to scripts directory (default: next to executable)") - showVersion := flag.Bool("version", false, "Show version") - flag.Parse() - - if *showVersion { - fmt.Printf("NetUtility %s\n", version) - return 0 - } - - // Determine scripts directory - scriptsDir := getDefaultScriptsDir() - if *scriptsDirFlag != "" { - scriptsDir = *scriptsDirFlag + scriptsDir, shouldExit, exitCode := parseCLIFlags() + if shouldExit { + return exitCode } // Ensure bin/ and scripts/ are executable (idempotent, silent) execPath, _ := os.Executable() ensureExecutable(filepath.Dir(execPath), scriptsDir) - // Load configuration - cfg, err := config.LoadConfig() - if err != nil { - fmt.Fprintf(os.Stderr, "Warning: Failed to load config: %v\n", err) - cfg = config.GetDefaultConfig() - } - - // Validate and sanitize configuration - if err := cfg.ValidateConfig(); err != nil { - fmt.Fprintf(os.Stderr, "Warning: Configuration validation failed: %v\n", err) - fmt.Fprintf(os.Stderr, "Sanitizing configuration...\n") - cfg.SanitizeConfig() - - // Save sanitized config - if saveErr := cfg.SaveConfig(); saveErr != nil { - fmt.Fprintf(os.Stderr, "Warning: Failed to save sanitized config: %v\n", saveErr) - } - } - - // Check if this is first-time setup - if cfg.NeedsFirstTimeSetup() { - fmt.Println("=== Welcome to NetUtility ===") - fmt.Println("First-time setup required.") - fmt.Println() - fmt.Println("NetUtility needs a workspace directory to store:") - fmt.Println(" • Network captures and analysis results") - fmt.Println(" • Vulnerability scan data") - fmt.Println(" • Configuration backups") - fmt.Println(" • Log files") - fmt.Println() - - if err := runFirstTimeSetup(cfg); err != nil { - fmt.Fprintf(os.Stderr, "Setup failed: %v\n", err) - return 1 - } + // Load and validate configuration + cfg := loadAndSanitizeConfig() - fmt.Println("Setup complete! Starting NetUtility...") - fmt.Println() + // Handle first-time setup if needed + if err := handleFirstTimeSetup(cfg); err != nil { + fmt.Fprintf(os.Stderr, "Setup failed: %v\n", err) + return 1 } // Initialize script registry with resolved scripts directory @@ -138,15 +96,7 @@ func run() int { } // Set up workspace environment - if cfg.IsWorkspaceConfigured() { - // Ensure workspace is writable (handles creation and ownership) - if err := cfg.EnsureWorkspaceWritable(); err != nil { - fmt.Fprintf(os.Stderr, "Warning: Failed to ensure workspace is writable: %v\n", err) - } else { - // Set NETUTIL_WORKDIR environment variable - os.Setenv("NETUTIL_WORKDIR", cfg.WorkspaceDir) - } - } + setupWorkspaceEnv(cfg) // Get remaining arguments after flag parsing args := flag.Args() @@ -191,6 +141,89 @@ func run() int { return 0 } +// parseCLIFlags parses command-line flags and returns the resolved scripts directory. +// If the program should exit early (e.g., --version), shouldExit is true and +// exitCode contains the appropriate code. +func parseCLIFlags() (scriptsDir string, shouldExit bool, exitCode int) { + scriptsDirFlag := flag.String("scripts-dir", "", "Path to scripts directory (default: next to executable)") + showVersion := flag.Bool("version", false, "Show version") + flag.Parse() + + if *showVersion { + fmt.Printf("NetUtility %s\n", version) + return "", true, 0 + } + + scriptsDir = getDefaultScriptsDir() + if *scriptsDirFlag != "" { + scriptsDir = *scriptsDirFlag + } + return scriptsDir, false, 0 +} + +// loadAndSanitizeConfig loads the application configuration, falling back to +// defaults on failure. If validation fails, the config is sanitized and saved. +func loadAndSanitizeConfig() *config.Config { + cfg, err := config.LoadConfig() + if err != nil { + fmt.Fprintf(os.Stderr, "Warning: Failed to load config: %v\n", err) + cfg = config.GetDefaultConfig() + } + + if err := cfg.ValidateConfig(); err != nil { + fmt.Fprintf(os.Stderr, "Warning: Configuration validation failed: %v\n", err) + fmt.Fprintf(os.Stderr, "Sanitizing configuration...\n") + cfg.SanitizeConfig() + + if saveErr := cfg.SaveConfig(); saveErr != nil { + fmt.Fprintf(os.Stderr, "Warning: Failed to save sanitized config: %v\n", saveErr) + } + } + + return cfg +} + +// handleFirstTimeSetup runs the initial workspace configuration if needed. +// Returns nil if no setup is needed or if setup succeeds. +func handleFirstTimeSetup(cfg *config.Config) error { + if !cfg.NeedsFirstTimeSetup() { + return nil + } + + fmt.Println("=== Welcome to NetUtility ===") + fmt.Println("First-time setup required.") + fmt.Println() + fmt.Println("NetUtility needs a workspace directory to store:") + fmt.Println(" • Network captures and analysis results") + fmt.Println(" • Vulnerability scan data") + fmt.Println(" • Configuration backups") + fmt.Println(" • Log files") + fmt.Println() + + if err := runFirstTimeSetup(cfg); err != nil { + return err + } + + fmt.Println("Setup complete! Starting NetUtility...") + fmt.Println() + return nil +} + +// setupWorkspaceEnv ensures the workspace directory is writable and sets the +// NETUTIL_WORKDIR environment variable for child processes. +func setupWorkspaceEnv(cfg *config.Config) { + if !cfg.IsWorkspaceConfigured() { + return + } + if err := cfg.EnsureWorkspaceWritable(); err != nil { + fmt.Fprintf(os.Stderr, "Warning: Failed to ensure workspace is writable: %v\n", err) + return + } + if err := os.Setenv("NETUTIL_WORKDIR", cfg.WorkspaceDir); err != nil { + fmt.Fprintf(os.Stderr, "Warning: failed to set NETUTIL_WORKDIR: %v\n", err) + } +} + // Command mappings for CLI shortcuts var commandMappings = map[string]ScriptInfo{ // Discovery @@ -501,7 +534,7 @@ func runScriptDirect(scriptPath string, scriptName string) bool { // Ask user to press enter to continue (TUI mode) fmt.Printf("\nPress Enter to return to menu...") - fmt.Scanln() + _, _ = fmt.Scanln() // Clear screen again fmt.Print("\033[2J\033[H") diff --git a/internal/app/privileges.go b/internal/app/privileges.go index d67e600..1b2cbaf 100644 --- a/internal/app/privileges.go +++ b/internal/app/privileges.go @@ -25,7 +25,7 @@ func escalatePrivileges() error { return fmt.Errorf("failed to get executable path: %w", err) } - cmd := exec.Command("sudo", append([]string{executable}, os.Args[1:]...)...) + cmd := exec.Command("sudo", append([]string{executable}, os.Args[1:]...)...) //nolint:gosec // G702: executable from os.Executable cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr diff --git a/internal/auth/credentials.go b/internal/auth/credentials.go index 5bf9fda..af25008 100644 --- a/internal/auth/credentials.go +++ b/internal/auth/credentials.go @@ -18,11 +18,11 @@ type Credentials struct { // File format: username:bcrypt_hash (one per line) // Lines starting with # are comments, blank lines are ignored func LoadCredentials(path string) (*Credentials, error) { - file, err := os.Open(path) + file, err := os.Open(path) //nolint:gosec // G304: path from CLI flag/config if err != nil { return nil, fmt.Errorf("failed to open credentials file: %w", err) } - defer file.Close() + defer func() { _ = file.Close() }() creds := &Credentials{ users: make(map[string]string), diff --git a/internal/config/config.go b/internal/config/config.go index 0438b96..d128d2c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -82,7 +82,7 @@ func LoadConfig() (*Config, error) { return GetDefaultConfig(), nil } - data, err := os.ReadFile(configPath) + data, err := os.ReadFile(configPath) //nolint:gosec // G304: configPath from executable path if err != nil { return nil, fmt.Errorf("failed to read config file: %w", err) } @@ -126,7 +126,7 @@ func (c *Config) SaveConfig() error { return fmt.Errorf("failed to marshal config: %w", err) } - if err := os.WriteFile(configPath, data, 0644); err != nil { + if err := os.WriteFile(configPath, data, 0600); err != nil { return fmt.Errorf("failed to write config file: %w", err) } @@ -200,7 +200,7 @@ func (c *Config) CreateWorkspace() error { } // Create main workspace directory with permissions that allow root access - if err := os.MkdirAll(c.WorkspaceDir, 0755); err != nil { + if err := os.MkdirAll(c.WorkspaceDir, 0750); err != nil { return fmt.Errorf("failed to create workspace directory: %w", err) } @@ -218,14 +218,14 @@ func (c *Config) CreateWorkspace() error { for _, subdir := range subdirs { path := filepath.Join(c.WorkspaceDir, subdir) // Use 0777 permissions so root can write to user's workspace - if err := os.MkdirAll(path, 0777); err != nil { + if err := os.MkdirAll(path, 0750); err != nil { return fmt.Errorf("failed to create subdirectory %s: %w", subdir, err) } } // Create symbolic links for latest results latestDir := filepath.Join(c.WorkspaceDir, "latest") - if err := os.MkdirAll(latestDir, 0777); err != nil { + if err := os.MkdirAll(latestDir, 0750); err != nil { return fmt.Errorf("failed to create latest directory: %w", err) } @@ -314,10 +314,10 @@ func isValidInterfaceName(name string) bool { // Interface names should contain only alphanumeric characters, dots, and hyphens for _, char := range name { - if !((char >= 'a' && char <= 'z') || - (char >= 'A' && char <= 'Z') || - (char >= '0' && char <= '9') || - char == '.' || char == '-' || char == '_') { + if (char < 'a' || char > 'z') && + (char < 'A' || char > 'Z') && + (char < '0' || char > '9') && + char != '.' && char != '-' && char != '_' { return false } } @@ -530,14 +530,14 @@ func (c *Config) FixWorkspacePermissions() error { continue // Skip non-existent directories } - // Set permissions to 0777 so root can write to user-owned directories - if err := os.Chmod(dirPath, 0777); err != nil { + // Set permissions to 0750 so root can write to user-owned directories + if err := os.Chmod(dirPath, 0750); err != nil { //nolint:gosec // G302: directory chmod fmt.Fprintf(os.Stderr, "Warning: Failed to set permissions on %s: %v\n", dirPath, err) } } // Also fix permissions on the main workspace directory - if err := os.Chmod(c.WorkspaceDir, 0777); err != nil { + if err := os.Chmod(c.WorkspaceDir, 0750); err != nil { //nolint:gosec // G302: directory chmod fmt.Fprintf(os.Stderr, "Warning: Failed to set permissions on workspace root: %v\n", err) } @@ -567,12 +567,12 @@ func (c *Config) EnsureWorkspaceWritable() error { // Test write access by creating a temporary file testFile := filepath.Join(c.WorkspaceDir, ".netutil_write_test") - if err := os.WriteFile(testFile, []byte("test"), 0644); err != nil { + if err := os.WriteFile(testFile, []byte("test"), 0600); err != nil { return fmt.Errorf("workspace not writable: %w", err) } // Clean up test file - os.Remove(testFile) + _ = os.Remove(testFile) return nil } @@ -600,7 +600,7 @@ func FixWorkspaceOwnershipForPath(dir string) { return } - filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + if err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { if err != nil { return nil } @@ -608,5 +608,7 @@ func FixWorkspaceOwnershipForPath(dir string) { fmt.Fprintf(os.Stderr, "Warning: Failed to change ownership of %s: %v\n", path, chownErr) } return nil - }) + }); err != nil { + fmt.Fprintf(os.Stderr, "Warning: error walking %s for chown: %v\n", dir, err) + } } diff --git a/internal/correlation/compliance_checks.go b/internal/correlation/compliance_checks.go index 07bd3fc..7fb874a 100644 --- a/internal/correlation/compliance_checks.go +++ b/internal/correlation/compliance_checks.go @@ -36,8 +36,8 @@ func worstSeverity(findings []ComplianceFinding) string { worst := "" for _, f := range findings { switch f.Severity { - case "critical": - return "critical" + case severityCritical: + return severityCritical case "warning": worst = "warning" case "ok": @@ -72,7 +72,7 @@ var reTelnetEnabled = regexp.MustCompile(`(?m)^\s*transport input\s+(telnet|all) func checkTelnet(cfg string) ComplianceFinding { if reTelnetEnabled.MatchString(cfg) { - return ComplianceFinding{"Telnet enabled", "critical", "VTY line allows telnet (transport input telnet/all)"} + return ComplianceFinding{"Telnet enabled", severityCritical, "VTY line allows telnet (transport input telnet/all)"} } return ComplianceFinding{"Telnet disabled", "ok", ""} } @@ -84,7 +84,7 @@ func checkHTTPServer(cfg string) ComplianceFinding { for _, l := range lines { t := strings.TrimSpace(l) if t == "ip http server" { - return ComplianceFinding{"HTTP management server enabled", "critical", "\"ip http server\" enables unencrypted management access"} + return ComplianceFinding{"HTTP management server enabled", severityCritical, "\"ip http server\" enables unencrypted management access"} } } return ComplianceFinding{"HTTP management server disabled", "ok", ""} @@ -95,7 +95,7 @@ var reDefaultSNMP = regexp.MustCompile(`(?mi)^\s*snmp-server community\s+(public func checkDefaultSNMP(cfg string) ComplianceFinding { if reDefaultSNMP.MatchString(cfg) { - return ComplianceFinding{"Default SNMP community", "critical", "SNMP community \"public\" or \"private\" configured"} + return ComplianceFinding{"Default SNMP community", severityCritical, "SNMP community \"public\" or \"private\" configured"} } return ComplianceFinding{"SNMP community strings", "ok", ""} } @@ -237,13 +237,13 @@ func checkHPComware(cfg string) []ComplianceFinding { var out []ComplianceFinding // Telnet: look for "user-interface vty" block with "protocol inbound telnet" if reComwareTelnet.MatchString(cfg) { - out = append(out, ComplianceFinding{"Telnet enabled", "critical", "\"protocol inbound telnet\" in user-interface vty"}) + out = append(out, ComplianceFinding{"Telnet enabled", severityCritical, "\"protocol inbound telnet\" in user-interface vty"}) } else { out = append(out, ComplianceFinding{"Telnet disabled", "ok", ""}) } // SNMP default communities if reComwareSNMP.MatchString(cfg) { - out = append(out, ComplianceFinding{"Default SNMP community", "critical", "SNMP community \"public\" or \"private\" configured"}) + out = append(out, ComplianceFinding{"Default SNMP community", severityCritical, "SNMP community \"public\" or \"private\" configured"}) } else { out = append(out, ComplianceFinding{"SNMP community strings", "ok", ""}) } @@ -274,13 +274,13 @@ func checkArubaCX(cfg string) []ComplianceFinding { var out []ComplianceFinding // Telnet if reArubaCXTelnet.MatchString(cfg) { - out = append(out, ComplianceFinding{"Telnet enabled", "critical", "\"telnet-server\" found in running config"}) + out = append(out, ComplianceFinding{"Telnet enabled", severityCritical, "\"telnet-server\" found in running config"}) } else { out = append(out, ComplianceFinding{"Telnet disabled", "ok", ""}) } // SNMP default communities if reArubaCXSNMP.MatchString(cfg) { - out = append(out, ComplianceFinding{"Default SNMP community", "critical", "SNMP community \"public\" or \"private\" configured"}) + out = append(out, ComplianceFinding{"Default SNMP community", severityCritical, "SNMP community \"public\" or \"private\" configured"}) } else { out = append(out, ComplianceFinding{"SNMP community strings", "ok", ""}) } @@ -311,13 +311,13 @@ func checkArubaSwitch(cfg string) []ComplianceFinding { var out []ComplianceFinding // Telnet if reArubaSwitchTelnet.MatchString(cfg) { - out = append(out, ComplianceFinding{"Telnet enabled", "critical", "\"telnet-server\" or \"ip telnet\" found in running config"}) + out = append(out, ComplianceFinding{"Telnet enabled", severityCritical, "\"telnet-server\" or \"ip telnet\" found in running config"}) } else { out = append(out, ComplianceFinding{"Telnet disabled", "ok", ""}) } // SNMP default communities if reArubaSwitchSNMP.MatchString(cfg) { - out = append(out, ComplianceFinding{"Default SNMP community", "critical", "SNMP community \"public\" or \"private\" configured"}) + out = append(out, ComplianceFinding{"Default SNMP community", severityCritical, "SNMP community \"public\" or \"private\" configured"}) } else { out = append(out, ComplianceFinding{"SNMP community strings", "ok", ""}) } @@ -364,7 +364,7 @@ func checkGeneric(cfg string) []ComplianceFinding { } // Default SNMP communities (keyword pattern works across vendors) if reGenericSNMP.MatchString(cfg) { - out = append(out, ComplianceFinding{"Default SNMP community", "critical", "\"public\" or \"private\" SNMP community string detected"}) + out = append(out, ComplianceFinding{"Default SNMP community", severityCritical, "\"public\" or \"private\" SNMP community string detected"}) } return out } diff --git a/internal/correlation/concurrency_test.go b/internal/correlation/concurrency_test.go index 5068969..4bbcaaa 100644 --- a/internal/correlation/concurrency_test.go +++ b/internal/correlation/concurrency_test.go @@ -17,7 +17,7 @@ func TestConcurrentAddScanResult(t *testing.T) { go func(idx int) { defer wg.Done() result := &ScanResult{ - ID: string(ScanTypeNetworkEnum) + "-" + string(rune('A'+idx)), + ID: string(ScanTypeNetworkEnum) + "-" + string(rune('A'+idx%26)), //nosec G115 — modulo keeps rune in range Type: ScanTypeNetworkEnum, Hosts: []Host{ {IP: "192.168.1." + string(rune('0'+idx%10))}, diff --git a/internal/correlation/config_enricher.go b/internal/correlation/config_enricher.go index ab3780b..740c0af 100644 --- a/internal/correlation/config_enricher.go +++ b/internal/correlation/config_enricher.go @@ -106,11 +106,7 @@ func (ce *ConfigEnricher) applyPhysicalLinks(idx *macIndex, correlations map[str if !ok { continue } - corr.PhysicalLinks = append(corr.PhysicalLinks, PhysicalLink{ - SwitchIP: entry.SwitchIP, - Interface: entry.Interface, - VLAN: entry.VLAN, - }) + corr.PhysicalLinks = append(corr.PhysicalLinks, PhysicalLink(entry)) } } @@ -133,7 +129,7 @@ func readMetadata(deviceDir string) (ip, vendor string) { // readFileContents reads a file and returns its content as a string. func readFileContents(path string) (string, error) { - data, err := os.ReadFile(path) + data, err := os.ReadFile(path) //nolint:gosec // G304: path from trusted workspace if err != nil { return "", err } diff --git a/internal/correlation/config_enricher_test.go b/internal/correlation/config_enricher_test.go index b33b168..7853fe9 100644 --- a/internal/correlation/config_enricher_test.go +++ b/internal/correlation/config_enricher_test.go @@ -178,11 +178,11 @@ func TestConfigEnricher_HasConfigs_Empty(t *testing.T) { func TestConfigEnricher_HasConfigs_WithDir(t *testing.T) { dir := t.TempDir() deviceDir := filepath.Join(dir, "configs", "10.0.0.1") - if err := os.MkdirAll(deviceDir, 0o755); err != nil { + if err := os.MkdirAll(deviceDir, 0750); err != nil { t.Fatal(err) } if err := os.WriteFile(filepath.Join(deviceDir, "metadata.txt"), - []byte("IP Address: 10.0.0.1\nVendor/OS: cisco_ios\n"), 0o644); err != nil { + []byte("IP Address: 10.0.0.1\nVendor/OS: cisco_ios\n"), 0600); err != nil { t.Fatal(err) } ce := NewConfigEnricher(dir) @@ -194,23 +194,23 @@ func TestConfigEnricher_HasConfigs_WithDir(t *testing.T) { func TestConfigEnricher_Enrich_ComplianceAndMAC(t *testing.T) { dir := t.TempDir() deviceDir := filepath.Join(dir, "configs", "10.0.0.1") - if err := os.MkdirAll(deviceDir, 0o755); err != nil { + if err := os.MkdirAll(deviceDir, 0750); err != nil { t.Fatal(err) } // Write metadata if err := os.WriteFile(filepath.Join(deviceDir, "metadata.txt"), - []byte("IP Address: 10.0.0.1\nVendor/OS: cisco_ios\nHostname: SW-CORE\n"), 0o644); err != nil { + []byte("IP Address: 10.0.0.1\nVendor/OS: cisco_ios\nHostname: SW-CORE\n"), 0600); err != nil { t.Fatal(err) } // Running config with telnet enabled if err := os.WriteFile(filepath.Join(deviceDir, "running_config.txt"), - []byte("hostname SW-CORE\nline vty 0 4\n transport input telnet\n"), 0o644); err != nil { + []byte("hostname SW-CORE\nline vty 0 4\n transport input telnet\n"), 0600); err != nil { t.Fatal(err) } // Compliance output with MAC table entry for a host if err := os.WriteFile(filepath.Join(deviceDir, "compliance_commands.txt"), - []byte("show mac address-table\n 10 aabb.cc00.0200 DYNAMIC Gi0/5\n"), 0o644); err != nil { + []byte("show mac address-table\n 10 aabb.cc00.0200 DYNAMIC Gi0/5\n"), 0600); err != nil { t.Fatal(err) } diff --git a/internal/correlation/correlator.go b/internal/correlation/correlator.go index 95ffaf9..8686b33 100644 --- a/internal/correlation/correlator.go +++ b/internal/correlation/correlator.go @@ -15,6 +15,15 @@ import ( "time" ) +const scanTypeSSLScan = "sslscan" + +const ( + severityCritical = "critical" + severityHigh = "high" + severityMedium = "medium" + subtypeSwitch = "switch" +) + // ScanType represents different types of network scans type ScanType string @@ -308,76 +317,13 @@ func (c *Correlator) correlateHost(hostIP string) { } // Merge screenshot metadata - if screenshotsData, ok := result.Metadata["screenshots"]; ok { - var screenshots []ScreenshotInfo - switch v := screenshotsData.(type) { - case []ScreenshotInfo: - screenshots = v - case []map[string]string: - // Produced by MergeScreenshotsIntoCorrelation on a second pass - for _, ss := range v { - screenshots = append(screenshots, ScreenshotInfo{ - IP: hostIP, - URL: ss["url"], - File: ss["file"], - StatusCode: ss["status_code"], - }) - } - case []interface{}: - for _, item := range v { - switch s := item.(type) { - case ScreenshotInfo: - screenshots = append(screenshots, s) - case map[string]interface{}: - ss := ScreenshotInfo{ - IP: getStringFromMap(s, "ip"), - URL: getStringFromMap(s, "url"), - File: getStringFromMap(s, "file"), - StatusCode: getStringFromMap(s, "status_code"), - } - screenshots = append(screenshots, ss) - } - } - } - - var hostScreenshots []ScreenshotInfo - for _, ss := range screenshots { - if ss.IP == hostIP { - hostScreenshots = append(hostScreenshots, ss) - } - } - - if len(hostScreenshots) > 0 { - MergeScreenshotsIntoCorrelation(correlation, hostScreenshots) - } - } + c.mergeScreenshotsFromResult(correlation, result, hostIP) // Collect services (merge across multiple scan results, prefer more detail) - for _, service := range result.Services { - if service.Host != hostIP { - continue - } - key := service.Host + "|" + strconv.Itoa(service.Port) + "|" + service.Protocol - if idx, exists := seenServices[key]; exists { - c.mergeService(&correlation.Services[idx], service) - continue - } - seenServices[key] = len(correlation.Services) - correlation.Services = append(correlation.Services, service) - } + c.collectServices(correlation, result, hostIP, seenServices) // Collect vulnerabilities (deduplicate across multiple scan results) - for _, vuln := range result.Vulnerabilities { - if vuln.Host != hostIP { - continue - } - key := vuln.Host + "|" + vuln.Title + "|" + vuln.Source + "|" + strconv.Itoa(vuln.Port) - if seenVulns[key] { - continue - } - seenVulns[key] = true - correlation.Vulnerabilities = append(correlation.Vulnerabilities, vuln) - } + c.collectVulnerabilities(correlation, result, hostIP, seenVulns) } } @@ -399,6 +345,90 @@ func (c *Correlator) correlateHost(hostIP string) { c.applyManualOverrides() } +// mergeScreenshotsFromResult extracts screenshot metadata from a scan result +// and merges any matching the target host into the correlation. +func (c *Correlator) mergeScreenshotsFromResult(correlation *CorrelationResult, result *ScanResult, hostIP string) { + screenshotsData, ok := result.Metadata["screenshots"] + if !ok { + return + } + + var screenshots []ScreenshotInfo + switch v := screenshotsData.(type) { + case []ScreenshotInfo: + screenshots = v + case []map[string]string: + // Produced by MergeScreenshotsIntoCorrelation on a second pass + for _, ss := range v { + screenshots = append(screenshots, ScreenshotInfo{ + IP: hostIP, + URL: ss["url"], + File: ss["file"], + StatusCode: ss["status_code"], + }) + } + case []interface{}: + for _, item := range v { + switch s := item.(type) { + case ScreenshotInfo: + screenshots = append(screenshots, s) + case map[string]interface{}: + ss := ScreenshotInfo{ + IP: getStringFromMap(s, "ip"), + URL: getStringFromMap(s, "url"), + File: getStringFromMap(s, "file"), + StatusCode: getStringFromMap(s, "status_code"), + } + screenshots = append(screenshots, ss) + } + } + } + + var hostScreenshots []ScreenshotInfo + for _, ss := range screenshots { + if ss.IP == hostIP { + hostScreenshots = append(hostScreenshots, ss) + } + } + + if len(hostScreenshots) > 0 { + MergeScreenshotsIntoCorrelation(correlation, hostScreenshots) + } +} + +// collectServices merges services from a scan result into the correlation, +// deduplicating by host+port+protocol. +func (c *Correlator) collectServices(correlation *CorrelationResult, result *ScanResult, hostIP string, seenServices map[string]int) { + for _, service := range result.Services { + if service.Host != hostIP { + continue + } + key := service.Host + "|" + strconv.Itoa(service.Port) + "|" + service.Protocol + if idx, exists := seenServices[key]; exists { + c.mergeService(&correlation.Services[idx], service) + continue + } + seenServices[key] = len(correlation.Services) + correlation.Services = append(correlation.Services, service) + } +} + +// collectVulnerabilities deduplicates and appends vulnerabilities from a scan +// result into the correlation. +func (c *Correlator) collectVulnerabilities(correlation *CorrelationResult, result *ScanResult, hostIP string, seenVulns map[string]bool) { + for _, vuln := range result.Vulnerabilities { + if vuln.Host != hostIP { + continue + } + key := vuln.Host + "|" + vuln.Title + "|" + vuln.Source + "|" + strconv.Itoa(vuln.Port) + if seenVulns[key] { + continue + } + seenVulns[key] = true + correlation.Vulnerabilities = append(correlation.Vulnerabilities, vuln) + } +} + // resultContainsHost checks if a scan result contains information about a host func (c *Correlator) resultContainsHost(result *ScanResult, hostIP string) bool { // Check targets @@ -520,7 +550,7 @@ func (c *Correlator) mergePorts(existing []Port, new []Port) []Port { if port.Banner != "" && existingPort.Banner == "" { existingPort.Banner = port.Banner } - if port.State == "open" { + if port.State == portStatusOpen { existingPort.State = port.State } portMap[key] = existingPort @@ -571,29 +601,52 @@ func (c *Correlator) calculateRiskScore(correlation *CorrelationResult) RiskBrea var breakdown RiskBreakdown factors := make([]RiskFactorDetail, 0) - // --- Vulnerability factor (max 500) --- - // sslscan and testssl findings are handled separately in the SSL factor below. - vulnScore := 0 + vulnScore := c.scoreVulnerabilities(correlation, &factors) + breakdown.VulnerabilityScore = vulnScore + + sslScore := c.scoreSSLIssues(correlation, &factors) + breakdown.SSLIssues = sslScore + + svcScore := c.scoreServiceExposure(correlation, &factors) + breakdown.ServiceExposure = svcScore + + portScore := c.scoreOpenPorts(correlation, &factors) + breakdown.OpenPortScore = portScore + + total := vulnScore + sslScore + svcScore + portScore + if total > 1000 { + total = 1000 + } + breakdown.Total = total + breakdown.Factors = factors + + return breakdown +} + +// scoreVulnerabilities computes the vulnerability sub-score (max 500) from +// non-SSL scan findings. +func (c *Correlator) scoreVulnerabilities(correlation *CorrelationResult, factors *[]RiskFactorDetail) int { + score := 0 criticalCount := 0 highCount := 0 for _, vuln := range correlation.Vulnerabilities { - if vuln.Source == "sslscan" || vuln.Source == "testssl" { + if vuln.Source == scanTypeSSLScan || vuln.Source == string(ScanTypeTestSSL) { continue } var pts int sev := strings.ToLower(vuln.Severity) switch sev { - case "critical": + case severityCritical: if criticalCount < 2 { pts = 150 } criticalCount++ - case "high": + case severityHigh: if highCount < 4 { pts = 80 } highCount++ - case "medium": + case severityMedium: pts = 40 case "low": pts = 15 @@ -601,8 +654,8 @@ func (c *Correlator) calculateRiskScore(correlation *CorrelationResult) RiskBrea pts = 5 } if pts > 0 { - vulnScore += pts - factors = append(factors, RiskFactorDetail{ + score += pts + *factors = append(*factors, RiskFactorDetail{ Category: "vulnerability", Title: vuln.Title, Score: pts, @@ -611,30 +664,33 @@ func (c *Correlator) calculateRiskScore(correlation *CorrelationResult) RiskBrea }) } } - if vulnScore > 500 { - vulnScore = 500 + if score > 500 { + return 500 } - breakdown.VulnerabilityScore = vulnScore + return score +} - // --- SSL/TLS factor (max 200) --- - slScore := 0 +// scoreSSLIssues computes the SSL/TLS sub-score (max 200) from sslscan and +// testssl findings. +func (c *Correlator) scoreSSLIssues(correlation *CorrelationResult, factors *[]RiskFactorDetail) int { + score := 0 for _, vuln := range correlation.Vulnerabilities { - if vuln.Source != "sslscan" && vuln.Source != "testssl" { + if vuln.Source != scanTypeSSLScan && vuln.Source != string(ScanTypeTestSSL) { continue } var pts int sev := strings.ToLower(vuln.Severity) switch sev { - case "critical": - pts = 100 // SSLv2/SSLv3 - case "high": - pts = 50 // weak cipher - case "medium": - pts = 30 // TLS 1.0/1.1 or cert issue + case severityCritical: + pts = 100 + case severityHigh: + pts = 50 + case severityMedium: + pts = 30 } if pts > 0 { - slScore += pts - factors = append(factors, RiskFactorDetail{ + score += pts + *factors = append(*factors, RiskFactorDetail{ Category: "ssl", Title: vuln.Title, Score: pts, @@ -643,13 +699,16 @@ func (c *Correlator) calculateRiskScore(correlation *CorrelationResult) RiskBrea }) } } - if slScore > 200 { - slScore = 200 + if score > 200 { + return 200 } - breakdown.SSLIssues = slScore + return score +} - // --- Service exposure factor (max 200) --- - svcScore := 0 +// scoreServiceExposure computes the service exposure sub-score (max 200) from +// exposed risky services. +func (c *Correlator) scoreServiceExposure(correlation *CorrelationResult, factors *[]RiskFactorDetail) int { + score := 0 serviceMap := make(map[string]bool) for _, svc := range correlation.Services { svcName := strings.ToLower(svc.Name) @@ -670,20 +729,20 @@ func (c *Correlator) calculateRiskScore(correlation *CorrelationResult) RiskBrea pts = 40 } if pts > 0 { - svcScore += pts - factors = append(factors, RiskFactorDetail{ + score += pts + *factors = append(*factors, RiskFactorDetail{ Category: "service", Title: fmt.Sprintf("%s exposed (port %d)", svc.Name, svc.Port), Score: pts, - Severity: "medium", + Severity: severityMedium, Source: "service-scan", }) } } // http without https if serviceMap["http"] && !serviceMap["https"] { - svcScore += 30 - factors = append(factors, RiskFactorDetail{ + score += 30 + *factors = append(*factors, RiskFactorDetail{ Category: "service", Title: "HTTP without HTTPS", Score: 30, @@ -691,55 +750,47 @@ func (c *Correlator) calculateRiskScore(correlation *CorrelationResult) RiskBrea Source: "service-scan", }) } - if svcScore > 200 { - svcScore = 200 + if score > 200 { + return 200 } - breakdown.ServiceExposure = svcScore + return score +} - // --- Open port factor (max 100) --- +// scoreOpenPorts computes the open-port sub-score (max 100) based on how many +// ports are open on the host. +func (c *Correlator) scoreOpenPorts(correlation *CorrelationResult, factors *[]RiskFactorDetail) int { openCount := 0 if correlation.HostInfo != nil { for _, port := range correlation.HostInfo.Ports { - if port.State == "open" { + if port.State == portStatusOpen { openCount++ } } } - // Also count from services if openCount == 0 { openCount = len(correlation.Services) } - var portScore int + var score int switch { case openCount > 50: - portScore = 100 + score = 100 case openCount > 20: - portScore = 60 + score = 60 case openCount > 5: - portScore = 30 + score = 30 case openCount > 0: - portScore = 10 + score = 10 } - if portScore > 0 { - factors = append(factors, RiskFactorDetail{ + if score > 0 { + *factors = append(*factors, RiskFactorDetail{ Category: "port", Title: fmt.Sprintf("%d open ports", openCount), - Score: portScore, + Score: score, Severity: "low", Source: "port-scan", }) } - breakdown.OpenPortScore = portScore - - // --- Total (cap 1000) --- - total := vulnScore + slScore + svcScore + portScore - if total > 1000 { - total = 1000 - } - breakdown.Total = total - breakdown.Factors = factors - - return breakdown + return score } // generateRecommendations generates security recommendations based on findings @@ -754,13 +805,13 @@ func (c *Correlator) generateRecommendations(correlation *CorrelationResult) []s for _, vuln := range correlation.Vulnerabilities { sev := strings.ToLower(vuln.Severity) switch sev { - case "critical": - if vuln.Source == "sslscan" || vuln.Source == "testssl" { + case severityCritical: + if vuln.Source == scanTypeSSLScan || vuln.Source == string(ScanTypeTestSSL) { } else { criticalCount++ } - case "high": - if vuln.Source == "sslscan" || vuln.Source == "testssl" { + case severityHigh: + if vuln.Source == scanTypeSSLScan || vuln.Source == string(ScanTypeTestSSL) { } else { highCount++ } @@ -835,120 +886,151 @@ func (c *Correlator) inferHostSubtype(correlation *CorrelationResult) { return } - osMatch := attrs["os_match"] - osVal := correlation.HostInfo.OS - vendor := strings.ToLower(attrs["vendor"]) - sysDesc := strings.ToLower(attrs["sys_description"]) - var subtype string - switch cat { case "windows": - // Use os_match first (from nmap -O), then fall back to OS field - osLower := strings.ToLower(osMatch) - if osLower == "" { - osLower = strings.ToLower(osVal) - } - if strings.Contains(osLower, "server") { - subtype = "server" - } else if strings.Contains(osLower, "windows 10") || - strings.Contains(osLower, "windows 11") || - strings.Contains(osLower, "professional") || - strings.Contains(osLower, "windows 7") || - strings.Contains(osLower, "windows 8") || - strings.Contains(osLower, "windows xp") || - strings.Contains(osLower, "vista") || - strings.Contains(osLower, "workstation") { - subtype = "workstation" - } - // Domain controller: has LDAP ports - if subtype == "" { - hasLDAP := false - for _, p := range correlation.HostInfo.Ports { - if (p.Number == 389 || p.Number == 636 || p.Number == 3268 || p.Number == 3269) && p.State == "open" { - hasLDAP = true - break - } - } - if hasLDAP { - subtype = "domain controller" - } - } - // SQL Server: has port 1433 - if subtype == "" { - for _, p := range correlation.HostInfo.Ports { - if p.Number == 1433 && p.State == "open" { - subtype = "sql server" - break - } - } - } - + subtype = c.inferWindowsSubtype(attrs, correlation.HostInfo.OS, correlation.HostInfo.Ports) case "network_device": - // Check vendor first — SNMP data is authoritative - if vendor == "printer" || strings.Contains(sysDesc, "printer") || - strings.Contains(sysDesc, "laser") || strings.Contains(sysDesc, "inkjet") { - subtype = "printer" - } - // Check sys_description for known device types - if subtype == "" { - if strings.Contains(sysDesc, "switch") || strings.Contains(sysDesc, "nexus") { - subtype = "switch" - } else if strings.Contains(sysDesc, "asa") { - subtype = "firewall" - } else if strings.Contains(sysDesc, "router") { - subtype = "router" - } else if strings.Contains(sysDesc, "storage") || strings.Contains(sysDesc, "netapp") { - subtype = "storage" - } else if strings.Contains(vendor, "ubiquiti") { - // Ubiquiti devices with port 80/443 are typically switches/routers - subtype = "switch" - } - } - // OS-based inference for network devices - if subtype == "" && osVal != "" { - osLower := strings.ToLower(osVal) - if strings.Contains(osLower, "cisco") && (strings.Contains(osLower, "asa") || strings.Contains(osLower, "adaptive")) { - subtype = "firewall" - } else if strings.Contains(osLower, "cisco") && strings.Contains(osLower, "nexus") { - subtype = "switch" - } else if strings.Contains(osLower, "cisco") && strings.Contains(osLower, "router") { - subtype = "router" - } else if strings.Contains(osLower, "netapp") { - subtype = "storage" - } else if strings.Contains(osLower, "brocade") { - subtype = "switch" - } else if strings.Contains(osLower, "aruba") { - subtype = "switch" - } + subtype = c.inferNetworkDeviceSubtype(attrs, correlation.HostInfo.OS, correlation.HostInfo.Ports) + case "linux": + subtype = c.inferLinuxSubtype(attrs, correlation.HostInfo.OS) + } + + if subtype != "" { + attrs["host_subtype"] = subtype + } +} + +// inferWindowsSubtype classifies a Windows host as server, workstation, +// domain controller, or SQL server based on OS match and open ports. +func (c *Correlator) inferWindowsSubtype(attrs map[string]string, osVal string, ports []Port) string { + osLower := strings.ToLower(attrs["os_match"]) + if osLower == "" { + osLower = strings.ToLower(osVal) + } + if strings.Contains(osLower, "server") { + return "server" + } + if strings.Contains(osLower, "windows 10") || + strings.Contains(osLower, "windows 11") || + strings.Contains(osLower, "professional") || + strings.Contains(osLower, "windows 7") || + strings.Contains(osLower, "windows 8") || + strings.Contains(osLower, "windows xp") || + strings.Contains(osLower, "vista") || + strings.Contains(osLower, "workstation") { + return "workstation" + } + // Domain controller: has LDAP ports + for _, p := range ports { + if (p.Number == 389 || p.Number == 636 || p.Number == 3268 || p.Number == 3269) && p.State == portStatusOpen { + return "domain controller" } - // Port-based fallback only if nothing else worked - if subtype == "" { - for _, p := range correlation.HostInfo.Ports { - if p.Number == 9100 && p.State == "open" { - subtype = "printer" - break - } - } + } + // SQL Server: has port 1433 + for _, p := range ports { + if p.Number == 1433 && p.State == portStatusOpen { + return "sql server" } + } + return "" +} - case "linux": - osLower := strings.ToLower(osMatch) - if osLower == "" { - osLower = strings.ToLower(osVal) - } - if strings.Contains(osLower, "openwrt") || strings.Contains(osLower, "mikrotik") { - subtype = "router" - } else if strings.Contains(osLower, "netapp") { - subtype = "storage" - } else if strings.Contains(osLower, "embedded") { - subtype = "embedded" +// inferNetworkDeviceSubtype classifies a network device using vendor info, +// sys_description patterns, OS match, and port fallbacks. +func (c *Correlator) inferNetworkDeviceSubtype(attrs map[string]string, osVal string, ports []Port) string { + vendor := strings.ToLower(attrs["vendor"]) + sysDesc := strings.ToLower(attrs["sys_description"]) + + // Vendor/SNMP-based check + if vendor == "printer" || strings.Contains(sysDesc, "printer") || + strings.Contains(sysDesc, "laser") || strings.Contains(sysDesc, "inkjet") { + return "printer" + } + if s := c.inferDeviceBySysDesc(sysDesc, vendor); s != "" { + return s + } + if s := c.inferDeviceByOS(osVal); s != "" { + return s + } + return c.inferDeviceByPorts(ports) +} + +// inferDeviceBySysDesc matches known device types from sys_description and +// vendor strings. +func (c *Correlator) inferDeviceBySysDesc(sysDesc, vendor string) string { + if strings.Contains(sysDesc, "switch") || strings.Contains(sysDesc, "nexus") { + return subtypeSwitch + } + if strings.Contains(sysDesc, "asa") { + return "firewall" + } + if strings.Contains(sysDesc, "router") { + return "router" + } + if strings.Contains(sysDesc, "storage") || strings.Contains(sysDesc, "netapp") { + return "storage" + } + if strings.Contains(vendor, "ubiquiti") { + return subtypeSwitch + } + return "" +} + +// inferDeviceByOS matches known device types from the OS string. +func (c *Correlator) inferDeviceByOS(osVal string) string { + if osVal == "" { + return "" + } + osLower := strings.ToLower(osVal) + if strings.Contains(osLower, "cisco") && (strings.Contains(osLower, "asa") || strings.Contains(osLower, "adaptive")) { + return "firewall" + } + if strings.Contains(osLower, "cisco") && strings.Contains(osLower, "nexus") { + return subtypeSwitch + } + if strings.Contains(osLower, "cisco") && strings.Contains(osLower, "router") { + return "router" + } + if strings.Contains(osLower, "netapp") { + return "storage" + } + if strings.Contains(osLower, "brocade") { + return subtypeSwitch + } + if strings.Contains(osLower, "aruba") { + return subtypeSwitch + } + return "" +} + +// inferDeviceByPorts uses port-based fallback to identify device type (e.g. +// port 9100 → printer). +func (c *Correlator) inferDeviceByPorts(ports []Port) string { + for _, p := range ports { + if p.Number == 9100 && p.State == portStatusOpen { + return "printer" } } + return "" +} - if subtype != "" { - attrs["host_subtype"] = subtype +// inferLinuxSubtype classifies a Linux host based on OS match patterns. +func (c *Correlator) inferLinuxSubtype(attrs map[string]string, osVal string) string { + osLower := strings.ToLower(attrs["os_match"]) + if osLower == "" { + osLower = strings.ToLower(osVal) + } + if strings.Contains(osLower, "openwrt") || strings.Contains(osLower, "mikrotik") { + return "router" + } + if strings.Contains(osLower, "netapp") { + return "storage" + } + if strings.Contains(osLower, "embedded") { + return "embedded" } + return "" } // GetCorrelationForHost returns correlation results for a specific host @@ -1029,7 +1111,7 @@ func (c *Correlator) saveResults() error { } correlationDir := filepath.Join(c.dataDir, "correlations") - if err := os.MkdirAll(correlationDir, 0755); err != nil { + if err := os.MkdirAll(correlationDir, 0750); err != nil { return fmt.Errorf("failed to create correlation directory: %w", err) } @@ -1040,7 +1122,7 @@ func (c *Correlator) saveResults() error { return fmt.Errorf("failed to marshal correlations: %w", err) } - if err := os.WriteFile(correlationFile, data, 0644); err != nil { + if err := os.WriteFile(correlationFile, data, 0600); err != nil { return fmt.Errorf("failed to write correlations: %w", err) } @@ -1116,7 +1198,7 @@ func (c *Correlator) loadManualOverrides() error { if path == "" { return nil } - data, err := os.ReadFile(path) + data, err := os.ReadFile(path) //nolint:gosec // G304: path from trusted workspace if os.IsNotExist(err) { return nil } @@ -1131,14 +1213,14 @@ func (c *Correlator) saveManualOverrides() error { if path == "" { return nil } - if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + if err := os.MkdirAll(filepath.Dir(path), 0750); err != nil { return fmt.Errorf("creating correlations directory: %w", err) } data, err := json.MarshalIndent(c.manualOverrides, "", " ") if err != nil { return fmt.Errorf("marshalling manual overrides: %w", err) } - if err := os.WriteFile(path, data, 0644); err != nil { + if err := os.WriteFile(path, data, 0600); err != nil { return err } c.fixCorrelationsOwnership() @@ -1159,7 +1241,7 @@ func (c *Correlator) loadExcludedHosts() error { if path == "" { return nil } - data, err := os.ReadFile(path) + data, err := os.ReadFile(path) //nolint:gosec // G304: path from trusted workspace if os.IsNotExist(err) { return nil } @@ -1182,14 +1264,14 @@ func (c *Correlator) saveExcludedHosts() error { if path == "" { return nil } - if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + if err := os.MkdirAll(filepath.Dir(path), 0750); err != nil { return fmt.Errorf("creating correlations dir: %w", err) } data, err := json.MarshalIndent(c.excludedHosts, "", " ") if err != nil { return fmt.Errorf("marshalling excluded hosts: %w", err) } - if err := os.WriteFile(path, data, 0644); err != nil { + if err := os.WriteFile(path, data, 0600); err != nil { return err } c.fixCorrelationsOwnership() @@ -1255,7 +1337,7 @@ func (c *Correlator) LoadResults() error { return nil } - data, err := os.ReadFile(correlationFile) + data, err := os.ReadFile(correlationFile) //nolint:gosec // G304: trusted workspace path if err != nil { return fmt.Errorf("failed to read correlations: %w", err) } @@ -1281,7 +1363,7 @@ func (c *Correlator) LoadResults() error { // ParseNmapOutput parses nmap output and creates a scan result func ParseNmapOutput(filePath, scanID string) (*ScanResult, error) { - content, err := os.ReadFile(filePath) + content, err := os.ReadFile(filePath) //nolint:gosec // G304: filePath from scan results if err != nil { return nil, err } @@ -1334,7 +1416,7 @@ func ParseNmapOutput(filePath, scanID string) (*ScanResult, error) { currentHost.Ports = append(currentHost.Ports, port) // Add to services - if port.State == "open" { + if port.State == portStatusOpen { service := Service{ Host: currentHost.IP, Port: portNum, diff --git a/internal/correlation/correlator_test.go b/internal/correlation/correlator_test.go index c5339fc..68652dc 100644 --- a/internal/correlation/correlator_test.go +++ b/internal/correlation/correlator_test.go @@ -889,7 +889,7 @@ PORT STATE SERVICE Nmap done: 2 IP addresses (2 hosts up) scanned in 1.23 seconds ` - if err := os.WriteFile(nmapFile, []byte(nmapOutput), 0644); err != nil { + if err := os.WriteFile(nmapFile, []byte(nmapOutput), 0600); err != nil { t.Fatalf("Failed to create test file: %v", err) } @@ -951,7 +951,7 @@ func TestSetManualCategory(t *testing.T) { // Override file written overridePath := filepath.Join(dir, "correlations", "manual_categories.json") - data, err := os.ReadFile(overridePath) + data, err := os.ReadFile(overridePath) //nolint:gosec // G304: test path if err != nil { t.Fatalf("reading override file: %v", err) } @@ -970,7 +970,7 @@ func TestManualOverrideSurvivesReload(t *testing.T) { // Write correlations.json with "unknown" category directly (bypassing SetManualCategory // so correlations.json does NOT contain the override — only manual_categories.json will). corrDir := filepath.Join(dir, "correlations") - if err := os.MkdirAll(corrDir, 0755); err != nil { + if err := os.MkdirAll(corrDir, 0750); err != nil { t.Fatalf("mkdir: %v", err) } correlations := map[string]*CorrelationResult{ @@ -983,14 +983,14 @@ func TestManualOverrideSurvivesReload(t *testing.T) { }, } corrData, _ := json.Marshal(correlations) - if err := os.WriteFile(filepath.Join(corrDir, "correlations.json"), corrData, 0644); err != nil { + if err := os.WriteFile(filepath.Join(corrDir, "correlations.json"), corrData, 0600); err != nil { t.Fatalf("writing correlations.json: %v", err) } // Write manual_categories.json with the override directly. overrides := map[string]string{"10.0.0.2": "network_device"} overrideData, _ := json.MarshalIndent(overrides, "", " ") - if err := os.WriteFile(filepath.Join(corrDir, "manual_categories.json"), overrideData, 0644); err != nil { + if err := os.WriteFile(filepath.Join(corrDir, "manual_categories.json"), overrideData, 0600); err != nil { t.Fatalf("writing manual_categories.json: %v", err) } @@ -1105,7 +1105,7 @@ Nmap done: 1 IP address (1 host up) scanned` tempDir := t.TempDir() nmapFile := filepath.Join(tempDir, "scan.nmap") - if err := os.WriteFile(nmapFile, []byte(input), 0644); err != nil { + if err := os.WriteFile(nmapFile, []byte(input), 0600); err != nil { t.Fatalf("Failed to create test file: %v", err) } diff --git a/internal/correlation/hostfiles.go b/internal/correlation/hostfiles.go index 4a47184..11a925b 100644 --- a/internal/correlation/hostfiles.go +++ b/internal/correlation/hostfiles.go @@ -103,11 +103,11 @@ func moveHostInSession(hostfilesDir, ip, targetPlainFile, category string) error } target := filepath.Join(hostfilesDir, targetPlainFile) - f, err := os.OpenFile(target, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + f, err := os.OpenFile(target, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600) //nolint:gosec // G304: path from trusted workspace if err != nil { return fmt.Errorf("opening %s: %w", target, err) } - defer f.Close() + defer func() { _ = f.Close() }() if _, err := fmt.Fprintln(f, ip); err != nil { return err } @@ -119,11 +119,11 @@ func moveHostInSession(hostfilesDir, ip, targetPlainFile, category string) error return fmt.Errorf("expected .txt suffix in plain file %q", targetPlainFile) } targetEnriched := filepath.Join(hostfilesDir, base+"_enriched.txt") - f2, err := os.OpenFile(targetEnriched, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + f2, err := os.OpenFile(targetEnriched, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600) //nolint:gosec // G304: path from trusted workspace if err != nil { return fmt.Errorf("opening %s: %w", targetEnriched, err) } - defer f2.Close() + defer func() { _ = f2.Close() }() // Update the category field in the enriched data // Format: "IP HOSTNAME CATEGORY [tags]" @@ -145,7 +145,7 @@ func moveHostInSession(hostfilesDir, ip, targetPlainFile, category string) error // Returns (true, nil) if any line was removed, (false, nil) if ip was not found, // (false, err) on I/O error. Missing files return (false, nil). func removeIPFromFile(path, ip string) (bool, error) { - f, err := os.Open(path) + f, err := os.Open(path) //nolint:gosec // G304: path from trusted workspace if os.IsNotExist(err) { return false, nil } @@ -168,7 +168,7 @@ func removeIPFromFile(path, ip string) (bool, error) { } keep = append(keep, line) } - f.Close() + _ = f.Close() if err := scanner.Err(); err != nil { return false, fmt.Errorf("scanning %s: %w", path, err) @@ -182,17 +182,17 @@ func removeIPFromFile(path, ip string) (bool, error) { if len(keep) > 0 { out += "\n" } - return true, os.WriteFile(path, []byte(out), 0644) + return true, os.WriteFile(path, []byte(out), 0600) } // extractEnrichedDataForIP reads an enriched file and returns the line for a specific IP. // Returns the full line (with hostname, category, tags) or empty string if not found. func extractEnrichedDataForIP(filepath, ip string) string { - f, err := os.Open(filepath) + f, err := os.Open(filepath) //nolint:gosec // G304: path from trusted workspace if err != nil { return "" } - defer f.Close() + defer func() { _ = f.Close() }() scanner := bufio.NewScanner(f) for scanner.Scan() { diff --git a/internal/correlation/hostfiles_test.go b/internal/correlation/hostfiles_test.go index c44c75a..5633002 100644 --- a/internal/correlation/hostfiles_test.go +++ b/internal/correlation/hostfiles_test.go @@ -10,11 +10,11 @@ import ( func makeSession(t *testing.T, discoveryDir, sessionName string, files map[string]string) string { t.Helper() hostfilesDir := filepath.Join(discoveryDir, sessionName, "hostfiles") - if err := os.MkdirAll(hostfilesDir, 0755); err != nil { + if err := os.MkdirAll(hostfilesDir, 0750); err != nil { t.Fatalf("makeSession: %v", err) } for name, content := range files { - if err := os.WriteFile(filepath.Join(hostfilesDir, name), []byte(content), 0644); err != nil { + if err := os.WriteFile(filepath.Join(hostfilesDir, name), []byte(content), 0600); err != nil { t.Fatalf("makeSession WriteFile: %v", err) } } @@ -23,7 +23,7 @@ func makeSession(t *testing.T, discoveryDir, sessionName string, files map[strin func readFile(t *testing.T, path string) string { t.Helper() - data, err := os.ReadFile(path) + data, err := os.ReadFile(path) //nolint:gosec // G304: test path if os.IsNotExist(err) { return "" } @@ -167,11 +167,11 @@ func TestMoveHostInHostfiles_UnknownCategory(t *testing.T) { func makeNestedSession(t *testing.T, discoveryDir, subpath string, files map[string]string) string { t.Helper() hostfilesDir := filepath.Join(discoveryDir, subpath, "hostfiles") - if err := os.MkdirAll(hostfilesDir, 0755); err != nil { + if err := os.MkdirAll(hostfilesDir, 0750); err != nil { t.Fatalf("makeNestedSession: %v", err) } for name, content := range files { - if err := os.WriteFile(filepath.Join(hostfilesDir, name), []byte(content), 0644); err != nil { + if err := os.WriteFile(filepath.Join(hostfilesDir, name), []byte(content), 0600); err != nil { t.Fatalf("makeNestedSession WriteFile: %v", err) } } diff --git a/internal/correlation/mac_table.go b/internal/correlation/mac_table.go index abea3a5..6d8ca0f 100644 --- a/internal/correlation/mac_table.go +++ b/internal/correlation/mac_table.go @@ -89,7 +89,7 @@ func normalizeMAC(raw string) string { return "" } for _, c := range stripped { - if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) { + if (c < '0' || c > '9') && (c < 'a' || c > 'f') { return "" } } diff --git a/internal/correlation/package.go b/internal/correlation/package.go index e340768..494e4a3 100644 --- a/internal/correlation/package.go +++ b/internal/correlation/package.go @@ -14,6 +14,8 @@ import ( "time" ) +const portStatusOpen = "open" + // CategoryFileNames maps category keys to their markdown output filenames. var categoryFileNames = map[string]string{ "windows": "windows.md", @@ -43,36 +45,18 @@ type HostEntry struct { // markdown hostlists and screenshot files. The archive is written to // {workspaceDir}/discovery/hostfile_distribution_YYYYMMDD_HHMMSS.tar.gz. func (c *Correlator) GenerateDistributionPackage() (string, error) { - correlations := c.GetAllCorrelations() - - // Build host entries grouped by category. - categoryEntries := make(map[string][]HostEntry) - for ip, corr := range correlations { - cat := hostCategoryFromResult(corr) - if cat == "unknown" || cat == "" { - continue - } - entry := buildHostEntry(ip, corr) - categoryEntries[cat] = append(categoryEntries[cat], entry) - } + categoryEntries := c.categorizeHosts() if len(categoryEntries) == 0 { return "", fmt.Errorf("no categorized hosts to package") } - // Sort entries within each category by IP. - for cat := range categoryEntries { - sort.Slice(categoryEntries[cat], func(i, j int) bool { - return compareIPsNumeric(categoryEntries[cat][i].IP, categoryEntries[cat][j].IP) - }) - } - timestamp := time.Now() ts := timestamp.Format("20060102_150405") archiveName := fmt.Sprintf("hostfile_distribution_%s.tar.gz", ts) discoveryDir := filepath.Join(c.workspaceDir, "discovery") - if err := os.MkdirAll(discoveryDir, 0755); err != nil { + if err := os.MkdirAll(discoveryDir, 0750); err != nil { return "", fmt.Errorf("creating discovery directory: %w", err) } archivePath := filepath.Join(discoveryDir, archiveName) @@ -95,7 +79,50 @@ func (c *Correlator) GenerateDistributionPackage() (string, error) { } // Copy screenshot files. - screenshotsDir := filepath.Join(tmpDir, "screenshots") + if err := copyScreenshotsToDir(tmpDir, categoryEntries); err != nil { + return "", err + } + + // Write metadata. + if err := writeMetadata(filepath.Join(tmpDir, "metadata.txt"), timestamp, categoryEntries); err != nil { + return "", fmt.Errorf("writing metadata: %w", err) + } + + // Create archive. + if err := createTarGz(tmpDir, archivePath); err != nil { + return "", fmt.Errorf("creating archive: %w", err) + } + + return archivePath, nil +} + +// categorizeHosts builds host entries grouped by category from all correlations, +// sorted by IP within each category. +func (c *Correlator) categorizeHosts() map[string][]HostEntry { + correlations := c.GetAllCorrelations() + categoryEntries := make(map[string][]HostEntry) + for ip, corr := range correlations { + cat := hostCategoryFromResult(corr) + if cat == "unknown" || cat == "" { + continue + } + entry := buildHostEntry(ip, corr) + categoryEntries[cat] = append(categoryEntries[cat], entry) + } + + for cat := range categoryEntries { + sort.Slice(categoryEntries[cat], func(i, j int) bool { + return compareIPsNumeric(categoryEntries[cat][i].IP, categoryEntries[cat][j].IP) + }) + } + + return categoryEntries +} + +// copyScreenshotsToDir copies screenshot files from host entries into a +// screenshots subdirectory within tmpDir. Missing screenshots produce warnings +// but do not abort the operation. +func copyScreenshotsToDir(tmpDir string, categoryEntries map[string][]HostEntry) error { hasScreenshots := false for _, entries := range categoryEntries { for _, e := range entries { @@ -108,35 +135,27 @@ func (c *Correlator) GenerateDistributionPackage() (string, error) { break } } - if hasScreenshots { - if err := os.MkdirAll(screenshotsDir, 0755); err != nil { - return "", fmt.Errorf("creating screenshots directory: %w", err) - } - for _, entries := range categoryEntries { - for _, e := range entries { - for _, ss := range e.ScreenshotFiles { - destName := filepath.Base(ss.File) - destPath := filepath.Join(screenshotsDir, destName) - if err := copyFile(ss.File, destPath); err != nil { - // Best-effort: missing screenshots should not abort the package. - _, _ = fmt.Fprintf(os.Stderr, "warning: skipping screenshot %s: %v\n", ss.File, err) - } - } - } - } + if !hasScreenshots { + return nil } - // Write metadata. - if err := writeMetadata(filepath.Join(tmpDir, "metadata.txt"), timestamp, categoryEntries); err != nil { - return "", fmt.Errorf("writing metadata: %w", err) + screenshotsDir := filepath.Join(tmpDir, "screenshots") + if err := os.MkdirAll(screenshotsDir, 0750); err != nil { + return fmt.Errorf("creating screenshots directory: %w", err) } - - // Create archive. - if err := createTarGz(tmpDir, archivePath); err != nil { - return "", fmt.Errorf("creating archive: %w", err) + for _, entries := range categoryEntries { + for _, e := range entries { + for _, ss := range e.ScreenshotFiles { + destName := filepath.Base(ss.File) + destPath := filepath.Join(screenshotsDir, destName) + if err := copyFile(ss.File, destPath); err != nil { + // Best-effort: missing screenshots should not abort the package. + _, _ = fmt.Fprintf(os.Stderr, "warning: skipping screenshot %s: %v\n", ss.File, err) + } + } + } } - - return archivePath, nil + return nil } // hostCategoryFromResult returns the category from HostInfo.Attributes, @@ -172,7 +191,7 @@ func buildHostEntry(ip string, corr *CorrelationResult) HostEntry { var portNums []int for _, p := range corr.HostInfo.Ports { - if p.State == "open" { + if p.State == portStatusOpen { portNums = append(portNums, p.Number) } } @@ -209,7 +228,7 @@ func formatScreenshotNotes(screenshots []ScreenshotInfo) string { // writeMarkdownFile writes a markdown file with a table of host entries. func writeMarkdownFile(path, category string, entries []HostEntry) error { - f, err := os.Create(path) + f, err := os.Create(path) //nolint:gosec // G304: path from trusted workspace if err != nil { return err } @@ -256,7 +275,7 @@ func writeMarkdownFile(path, category string, entries []HostEntry) error { // writeMetadata creates a metadata.txt with timestamp and host counts. func writeMetadata(path string, ts time.Time, categoryEntries map[string][]HostEntry) error { - f, err := os.Create(path) + f, err := os.Create(path) //nolint:gosec // G304: path from trusted workspace if err != nil { return err } @@ -278,16 +297,16 @@ func writeMetadata(path string, ts time.Time, categoryEntries map[string][]HostE // copyFile copies src to dst. Creates dst parent directories if needed. func copyFile(src, dst string) error { - if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil { + if err := os.MkdirAll(filepath.Dir(dst), 0750); err != nil { return err } - in, err := os.Open(src) + in, err := os.Open(src) //nolint:gosec // G304: path from trusted workspace if err != nil { return err } defer func() { _ = in.Close() }() - out, err := os.Create(dst) + out, err := os.Create(dst) //nolint:gosec // G304: path from trusted workspace if err != nil { return err } @@ -299,7 +318,7 @@ func copyFile(src, dst string) error { // createTarGz creates a tar.gz archive of the contents of srcDir at dstPath. func createTarGz(srcDir, dstPath string) error { - out, err := os.Create(dstPath) + out, err := os.Create(dstPath) //nolint:gosec // G304: path from trusted workspace if err != nil { return err } @@ -343,7 +362,7 @@ func createTarGz(srcDir, dstPath string) error { } func copyFileToTar(tw *tar.Writer, filePath string) error { - f, err := os.Open(filePath) + f, err := os.Open(filePath) //nolint:gosec // G304: path from trusted workspace if err != nil { return err } diff --git a/internal/correlation/package_test.go b/internal/correlation/package_test.go index 81a0cef..464f26f 100644 --- a/internal/correlation/package_test.go +++ b/internal/correlation/package_test.go @@ -10,6 +10,7 @@ import ( "testing" ) +const maxDecompressionSize = 100 * 1024 * 1024 // 100 MB func TestFormatScreenshotNotes(t *testing.T) { tests := []struct { name string @@ -262,7 +263,7 @@ func TestWriteMarkdownFile(t *testing.T) { t.Fatalf("writeMarkdownFile() error: %v", err) } - data, err := os.ReadFile(path) + data, err := os.ReadFile(path) //nolint:gosec // G304: test path if err != nil { t.Fatalf("reading output: %v", err) } @@ -295,6 +296,42 @@ func TestWriteMarkdownFile(t *testing.T) { } } +// readTarEntries reads a .tar.gz archive at archivePath and returns a map of +// entry name to file content. It fatals the test on any I/O error. +func readTarEntries(t *testing.T, archivePath string) map[string]string { + t.Helper() + entries := make(map[string]string) + + f, err := os.Open(archivePath) //nolint:gosec // G304: test path + if err != nil { + t.Fatalf("opening archive: %v", err) + } + defer func() { _ = f.Close() }() + + gz, err := gzip.NewReader(f) + if err != nil { + t.Fatalf("gzip reader: %v", err) + } + defer func() { _ = gz.Close() }() + + tr := tar.NewReader(gz) + for { + hdr, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("reading tar: %v", err) + } + var buf strings.Builder + if _, err := io.Copy(&buf, io.LimitReader(tr, maxDecompressionSize)); err != nil { + t.Fatalf("reading %s: %v", hdr.Name, err) + } + entries[hdr.Name] = buf.String() + } + return entries +} + func TestGenerateDistributionPackage(t *testing.T) { tmpDir := t.TempDir() dataDir := t.TempDir() @@ -362,93 +399,49 @@ func TestGenerateDistributionPackage(t *testing.T) { } // Verify archive contents. - f, err := os.Open(archivePath) - if err != nil { - t.Fatalf("opening archive: %v", err) - } - defer func() { _ = f.Close() }() + entries := readTarEntries(t, archivePath) - gz, err := gzip.NewReader(f) - if err != nil { - t.Fatalf("gzip reader: %v", err) - } - defer func() { _ = gz.Close() }() - - tr := tar.NewReader(gz) - foundFiles := make(map[string]bool) - for { - hdr, err := tr.Next() - if err == io.EOF { - break + if content, ok := entries["windows.md"]; !ok { + t.Error("archive missing windows.md") + } else { + if !strings.Contains(content, "192.168.1.10") { + t.Error("windows.md missing windows host IP") } - if err != nil { - t.Fatalf("reading tar: %v", err) + if !strings.Contains(content, "DESKTOP-ABC") { + t.Error("windows.md missing hostname") } - foundFiles[hdr.Name] = true - - // Verify windows.md content. - if hdr.Name == "windows.md" { - var buf strings.Builder - if _, err := io.Copy(&buf, tr); err != nil { - t.Fatalf("reading windows.md: %v", err) - } - content := buf.String() - if !strings.Contains(content, "192.168.1.10") { - t.Error("windows.md missing windows host IP") - } - if !strings.Contains(content, "DESKTOP-ABC") { - t.Error("windows.md missing hostname") - } - if strings.Contains(content, "192.168.1.20") { - t.Error("windows.md should not contain linux host") - } + if strings.Contains(content, "192.168.1.20") { + t.Error("windows.md should not contain linux host") } + } - // Verify linux.md content. - if hdr.Name == "linux.md" { - var buf strings.Builder - if _, err := io.Copy(&buf, tr); err != nil { - t.Fatalf("reading linux.md: %v", err) - } - content := buf.String() - if !strings.Contains(content, "192.168.1.20") { - t.Error("linux.md missing linux host IP") - } - if !strings.Contains(content, "webserver") { - t.Error("linux.md missing hostname") - } + if content, ok := entries["linux.md"]; !ok { + t.Error("archive missing linux.md") + } else { + if !strings.Contains(content, "192.168.1.20") { + t.Error("linux.md missing linux host IP") } - - // Verify metadata.txt. - if hdr.Name == "metadata.txt" { - var buf strings.Builder - if _, err := io.Copy(&buf, tr); err != nil { - t.Fatalf("reading metadata.txt: %v", err) - } - content := buf.String() - if !strings.Contains(content, "windows: 1") { - t.Error("metadata missing windows count") - } - if !strings.Contains(content, "linux: 1") { - t.Error("metadata missing linux count") - } - if !strings.Contains(content, "total: 2") { - t.Error("metadata missing total count") - } + if !strings.Contains(content, "webserver") { + t.Error("linux.md missing hostname") } } - if !foundFiles["windows.md"] { - t.Error("archive missing windows.md") - } - if !foundFiles["linux.md"] { - t.Error("archive missing linux.md") - } - if !foundFiles["metadata.txt"] { + if content, ok := entries["metadata.txt"]; !ok { t.Error("archive missing metadata.txt") + } else { + if !strings.Contains(content, "windows: 1") { + t.Error("metadata missing windows count") + } + if !strings.Contains(content, "linux: 1") { + t.Error("metadata missing linux count") + } + if !strings.Contains(content, "total: 2") { + t.Error("metadata missing total count") + } } + // Unknown hosts should not produce a file. - if foundFiles["unknown.md"] { + if _, ok := entries["unknown.md"]; ok { t.Error("archive should not contain unknown.md") } } @@ -459,11 +452,11 @@ func TestGenerateDistributionPackageWithScreenshots(t *testing.T) { // Create a fake screenshot file. screenshotDir := filepath.Join(tmpDir, "captures", "screenshots") - if err := os.MkdirAll(screenshotDir, 0755); err != nil { + if err := os.MkdirAll(screenshotDir, 0750); err != nil { t.Fatal(err) } screenshotPath := filepath.Join(screenshotDir, "http--192.168.1.10-80.jpeg") - if err := os.WriteFile(screenshotPath, []byte("fake png data"), 0644); err != nil { + if err := os.WriteFile(screenshotPath, []byte("fake png data"), 0600); err != nil { t.Fatal(err) } @@ -495,7 +488,7 @@ func TestGenerateDistributionPackageWithScreenshots(t *testing.T) { } // Verify archive contains screenshots directory and the file. - f, err := os.Open(archivePath) + f, err := os.Open(archivePath) //nolint:gosec // G304: test path if err != nil { t.Fatalf("opening archive: %v", err) } @@ -529,7 +522,7 @@ func TestGenerateDistributionPackageWithScreenshots(t *testing.T) { // Verify markdown contains screenshot wikilinks. if hdr.Name == "windows.md" { var buf strings.Builder - if _, err := io.Copy(&buf, tr); err != nil { + if _, err := io.Copy(&buf, io.LimitReader(tr, maxDecompressionSize)); err != nil { t.Fatalf("reading windows.md: %v", err) } content := buf.String() diff --git a/internal/correlation/parsers.go b/internal/correlation/parsers.go index 5bf23f9..e8617cb 100644 --- a/internal/correlation/parsers.go +++ b/internal/correlation/parsers.go @@ -94,69 +94,91 @@ func (rp *ResultParser) determineScanType(scriptPath, outputContent string) Scan return ScanTypeHostCategorization } - // Nmap XML detection — check before generic name patterns + if t := rp.detectScanTypeByContent(contentLower, scriptName, outputContent); t != ScanTypePortScan { + return t + } + if t := rp.detectScanTypeByScriptName(scriptName, scriptPath); t != ScanTypePortScan { + return t + } + return ScanTypePortScan // Default fallback +} + +// detectScanTypeByContent inspects output content to identify the scan type. +func (rp *ResultParser) detectScanTypeByContent(contentLower, scriptName, outputContent string) ScanType { + // Nmap XML detection if strings.Contains(contentLower, " 4 { - portNum, _ := strconv.Atoi(matches[1]) - protocol := matches[2] - state := matches[3] - serviceInfo := strings.TrimSpace(matches[4]) - - port := Port{ - Number: portNum, - Protocol: protocol, - State: state, - } - - // Parse service information - serviceParts := strings.Fields(serviceInfo) - if len(serviceParts) > 0 { - port.Service = serviceParts[0] - if len(serviceParts) > 1 { - port.Version = strings.Join(serviceParts[1:], " ") - } - } - - currentHost.Ports = append(currentHost.Ports, port) - - // Add to services if open - if state == "open" { - service := Service{ - Host: currentHost.IP, - Port: portNum, - Protocol: protocol, - Name: port.Service, - Version: port.Version, - } - result.Services = append(result.Services, service) - } - } + rp.parsePortLine(result, currentHost, line) } // OS detection @@ -360,6 +331,49 @@ func (rp *ResultParser) parsePortScan(result *ScanResult, content string) (*Scan return result, nil } +// parsePortLine parses a single port line from nmap output and updates the host and result. +func (rp *ResultParser) parsePortLine(result *ScanResult, host *Host, line string) { + portRegex := regexp.MustCompile(`(\d+)/(tcp|udp)\s+(\w+)\s+(.*)`) + matches := portRegex.FindStringSubmatch(line) + if len(matches) <= 4 { + return + } + + portNum, _ := strconv.Atoi(matches[1]) + protocol := matches[2] + state := matches[3] + serviceInfo := strings.TrimSpace(matches[4]) + + port := Port{ + Number: portNum, + Protocol: protocol, + State: state, + } + + // Parse service information + serviceParts := strings.Fields(serviceInfo) + if len(serviceParts) > 0 { + port.Service = serviceParts[0] + if len(serviceParts) > 1 { + port.Version = strings.Join(serviceParts[1:], " ") + } + } + + host.Ports = append(host.Ports, port) + + // Add to services if open + if state == "open" { + service := Service{ + Host: host.IP, + Port: portNum, + Protocol: protocol, + Name: port.Service, + Version: port.Version, + } + result.Services = append(result.Services, service) + } +} + // parseVulnerabilityScan parses vulnerability scan output func (rp *ResultParser) parseVulnerabilityScan(result *ScanResult, content string) (*ScanResult, error) { lines := strings.Split(content, "\n") @@ -384,14 +398,14 @@ func (rp *ResultParser) parseVulnerabilityScan(result *ScanResult, content strin if strings.Contains(lineLower, "not vulnerable") || strings.Contains(lineLower, "no vulnerable") { continue } - severity := "medium" // Default severity + severity := severityMedium // Default severity title := line // Determine severity from keywords - if strings.Contains(lineLower, "critical") { - severity = "critical" + if strings.Contains(lineLower, severityCritical) { + severity = severityCritical } else if strings.Contains(lineLower, "high") { - severity = "high" + severity = severityHigh } else if strings.Contains(lineLower, "low") { severity = "low" } @@ -436,13 +450,9 @@ func (rp *ResultParser) parseNetworkCapture(result *ScanResult, content string) } } - // Parse protocols and services from packet captures - if strings.Contains(strings.ToLower(line), "http") { - // Extract HTTP traffic details - } - if strings.Contains(strings.ToLower(line), "ssh") { - // Extract SSH traffic details - } + // TODO: parse HTTP/SSH traffic details from packet captures + _ = strings.Contains(strings.ToLower(line), "http") + _ = strings.Contains(strings.ToLower(line), "ssh") } // Create host entries for discovered IPs @@ -561,7 +571,7 @@ func (rp *ResultParser) parseGenericOutput(result *ScanResult, content string) ( // ParseResultFile parses a result file and returns a scan result func (rp *ResultParser) ParseResultFile(filePath string) (*ScanResult, error) { - content, err := os.ReadFile(filePath) + content, err := os.ReadFile(filePath) //nolint:gosec // G304: path from trusted workspace if err != nil { return nil, fmt.Errorf("failed to read file: %w", err) } @@ -660,6 +670,19 @@ type niktoItem struct { Text string `xml:",chardata"` } +// niktoGroupKey uniquely identifies a group of related nikto findings. +type niktoGroupKey struct { + host string + port int + id string +} + +// niktoItemGroup holds a single item within a nikto finding group. +type niktoItemGroup struct { + title string + severity string +} + // --- Nmap XML types for parsing --- // nmapRun is the top-level XML structure for nmap XML output. @@ -798,79 +821,84 @@ func (rp *ResultParser) parseNmapXML(result *ScanResult, content string) (*ScanR } result.Targets = append(result.Targets, ip) - // Extract OS detection from nmap -O scan (osmatch). - // Store the best match as host OS and os_match attribute. - if h.OS != nil && len(h.OS.Matches) > 0 { - best := h.OS.Matches[0] - host.OS = best.Name - host.OSDetails = best.Name - host.Attributes["os_match"] = best.Name - if best.Accuracy != "" { - host.Attributes["os_match_accuracy"] = best.Accuracy - } - // Extract device type from osclass if available - for _, cls := range best.Classes { - if cls.Type != "" { - host.Attributes["device_type"] = cls.Type - break - } - } + extractNmapOSDetection(&host, h.OS) + rp.processNmapPorts(result, &host, ip, h.Ports) + + + result.Hosts = append(result.Hosts, host) + } + + return result, nil +} + +// extractNmapOSDetection extracts OS detection from nmap -O scan results +// and populates host attributes accordingly. +func extractNmapOSDetection(host *Host, os *nmapOS) { + if os == nil || len(os.Matches) == 0 { + return + } + best := os.Matches[0] + host.OS = best.Name + host.OSDetails = best.Name + host.Attributes["os_match"] = best.Name + if best.Accuracy != "" { + host.Attributes["os_match_accuracy"] = best.Accuracy + } + for _, cls := range best.Classes { + if cls.Type != "" { + host.Attributes["device_type"] = cls.Type + break } + } +} - if h.Ports == nil { - result.Hosts = append(result.Hosts, host) - continue +// processNmapPorts processes port entries from an nmap host, populating +// the host's port list, open services in the result, and script findings. +func (rp *ResultParser) processNmapPorts(result *ScanResult, host *Host, ip string, ports *nmapPorts) { + if ports == nil { + return + } + for _, p := range ports.Ports { + port := Port{ + Number: p.PortID, + Protocol: p.Protocol, + State: p.State.State, } + if p.Service != nil { + port.Service = p.Service.Name + port.Version = p.Service.Version + if p.Service.Product != "" { + port.Banner = p.Service.Product + if p.Service.Version != "" { + port.Banner += " " + p.Service.Version + } + } + } + host.Ports = append(host.Ports, port) - for _, p := range h.Ports.Ports { - port := Port{ - Number: p.PortID, + if p.State.State == "open" { + svc := Service{ + Host: ip, + Port: p.PortID, Protocol: p.Protocol, - State: p.State.State, + Name: port.Service, } if p.Service != nil { - port.Service = p.Service.Name - port.Version = p.Service.Version - if p.Service.Product != "" { - port.Banner = p.Service.Product - if p.Service.Version != "" { - port.Banner += " " + p.Service.Version - } - } + svc.Version = p.Service.Version + svc.Product = p.Service.Product + svc.ExtraInfo = p.Service.ExtraInfo + svc.Confidence = p.Service.Conf } - host.Ports = append(host.Ports, port) - - if p.State.State == "open" { - svc := Service{ - Host: ip, - Port: p.PortID, - Protocol: p.Protocol, - Name: port.Service, - } - if p.Service != nil { - svc.Version = p.Service.Version - svc.Product = p.Service.Product - svc.ExtraInfo = p.Service.ExtraInfo - svc.Confidence = p.Service.Conf - } - result.Services = append(result.Services, svc) - } - - // Process NSE scripts for vulnerabilities. - for _, script := range p.Scripts { - // Skip vulners script — it's a CPE database lookup, not findings. - if script.ID == "vulners" { - continue - } + result.Services = append(result.Services, svc) + } - rp.extractScriptFindings(result, ip, p.PortID, script) + for _, script := range p.Scripts { + if script.ID == "vulners" { + continue } + rp.extractScriptFindings(result, ip, p.PortID, script) } - - result.Hosts = append(result.Hosts, host) } - - return result, nil } // extractScriptFindings processes a single NSE script element and extracts @@ -891,9 +919,10 @@ func (rp *ResultParser) extractScriptFindings(result *ScanResult, hostIP string, rp.addVulnFromTable(result, hostIP, portNum, script.ID, child) } } + // No state at all — informational script (e.g., http-cookie-flags, http-server-header). + // Skip: these are not vulnerability findings. if !foundChild && state == "" { - // No state at all — informational script (e.g., http-cookie-flags, http-server-header). - // Skip: these are not vulnerability findings. + continue } continue } @@ -913,9 +942,9 @@ func (rp *ResultParser) addVulnFromTable(result *ScanResult, hostIP string, port title = scriptID } - severity := "medium" + severity := severityMedium if state == "VULNERABLE" { - severity = "high" + severity = severityHigh } // Extract CVE from IDs table. @@ -1006,24 +1035,7 @@ func splitXMLDocuments(content string) []string { // Nikto appends per-host XML to the same file, producing concatenated XML // documents. We split on ...). - var scans niktoScans - if err := xml.Unmarshal([]byte(doc), &scans); err == nil && len(scans.Scans) > 0 { - for _, s := range scans.Scans { - allDetails = append(allDetails, s.ScanDetails...) - } - continue - } - // Try singular root (...). - var single niktoScan - if err := xml.Unmarshal([]byte(doc), &single); err != nil { - continue - } - allDetails = append(allDetails, single.ScanDetails...) - } + allDetails := rp.parseNiktoScanDetails(content) if len(allDetails) == 0 { return rp.parseGenericOutput(result, content) @@ -1048,120 +1060,136 @@ func (rp *ResultParser) parseNiktoXMLResult(result *ScanResult, content string) result.Targets = append(result.Targets, ip) } - // Parse port if available port := 0 if details.TargetPort != "" { port, _ = strconv.Atoi(details.TargetPort) } - // First pass: collect items grouped by (host, port, itemID). - // Items with the same nikto ID (e.g., "013587" for missing security - // headers) represent the same class of finding and should be - // consolidated into one scored entry. - type itemGroup struct { - title string - severity string + groups := groupNiktoItems(details.Items, ip, port) + emitNiktoVulnerabilities(result, groups, result.Timestamp) + } + + return result, nil +} + +// parseNiktoScanDetails splits concatenated XML documents and unmarshals nikto scans. +func (rp *ResultParser) parseNiktoScanDetails(content string) []niktoScanDetails { + var allDetails []niktoScanDetails + for _, doc := range splitXMLDocuments(content) { + // Try plural wrapper first (...). + var scans niktoScans + if err := xml.Unmarshal([]byte(doc), &scans); err == nil && len(scans.Scans) > 0 { + for _, s := range scans.Scans { + allDetails = append(allDetails, s.ScanDetails...) + } + continue } - type groupKey struct { - host string - port int - id string + // Try singular root (...). + var single niktoScan + if err := xml.Unmarshal([]byte(doc), &single); err != nil { + continue } - groups := make(map[groupKey][]itemGroup) + allDetails = append(allDetails, single.ScanDetails...) + } + return allDetails +} - for _, item := range details.Items { - severity := niktoSeverity(item) +// groupNiktoItems collects items grouped by (host, port, itemID). +// Items with the same nikto ID represent the same class of finding and are +// consolidated into one scored entry. +func groupNiktoItems(items []niktoItem, ip string, port int) map[niktoGroupKey][]niktoItemGroup { + groups := make(map[niktoGroupKey][]niktoItemGroup) + for _, item := range items { + severity := niktoSeverity(item) - title := item.Description - if title == "" { - title = item.DescAttr - } - if title == "" { - title = item.Text - } - if title == "" && item.URL != "" { - title = fmt.Sprintf("%s %s", item.Method, item.URL) - } - if title == "" { - title = "Nikto finding" - } + title := item.Description + if title == "" { + title = item.DescAttr + } + if title == "" { + title = item.Text + } + if title == "" && item.URL != "" { + title = fmt.Sprintf("%s %s", item.Method, item.URL) + } + if title == "" { + title = "Nikto finding" + } - // Use a unique key for items without a nikto ID so they aren't - // incorrectly grouped with other items that also lack an ID. - itemID := item.ID - if itemID == "" { - itemID = "_" + strconv.Itoa(len(groups)) - } - key := groupKey{host: ip, port: port, id: itemID} - groups[key] = append(groups[key], itemGroup{ - title: strings.TrimSpace(title), - severity: severity, + // Use a unique key for items without a nikto ID so they aren't + // incorrectly grouped with other items that also lack an ID. + itemID := item.ID + if itemID == "" { + itemID = "_" + strconv.Itoa(len(groups)) + } + key := niktoGroupKey{host: ip, port: port, id: itemID} + groups[key] = append(groups[key], niktoItemGroup{ + title: strings.TrimSpace(title), + severity: severity, + }) + } + return groups +} + +// emitNiktoVulnerabilities consolidates grouped items into vulnerability entries. +// For multi-item groups, it produces a single entry with count and suffixes. +func emitNiktoVulnerabilities(result *ScanResult, groups map[niktoGroupKey][]niktoItemGroup, timestamp time.Time) { + for key, items := range groups { + if len(items) == 1 { + result.Vulnerabilities = append(result.Vulnerabilities, Vulnerability{ + Host: key.host, + Port: key.port, + Title: items[0].title, + Severity: items[0].severity, + Source: "nikto", + Discovery: timestamp, }) + continue } - // Second pass: emit one vulnerability per group. - // For multi-item groups, consolidate into a single entry with count. - for key, items := range groups { - if len(items) == 1 { - result.Vulnerabilities = append(result.Vulnerabilities, Vulnerability{ - Host: key.host, - Port: key.port, - Title: items[0].title, - Severity: items[0].severity, - Source: "nikto", - Discovery: result.Timestamp, - }) - continue - } + consolidated := consolidateNiktoGroup(items) - // Extract the common prefix from titles for the group title. - // e.g., "Suggested security header missing: referrer-policy" - // → prefix = "Suggested security header missing" - // → suffixes = ["referrer-policy", "x-content-type-options", ...] - prefix := commonPrefix(items[0].title, items[1].title) - for _, it := range items[2:] { - prefix = commonPrefix(prefix, it.title) - } - prefix = strings.TrimRight(prefix, ": ") - - // Collect the variable suffixes after the common prefix. - var suffixes []string - for _, it := range items { - suffix := strings.TrimPrefix(it.title, prefix) - suffix = strings.TrimLeft(suffix, ": ") - if suffix != "" { - suffixes = append(suffixes, suffix) - } + // Use the highest severity from the group. + sev := items[0].severity + for _, it := range items[1:] { + if severityRank(it.severity) > severityRank(sev) { + sev = it.severity } + } - // Build consolidated title. - consolidated := prefix - if len(suffixes) > 0 { - consolidated = fmt.Sprintf("%s (%d): %s", prefix, len(items), strings.Join(suffixes, ", ")) - } else { - consolidated = fmt.Sprintf("%s (%d instances)", prefix, len(items)) - } + result.Vulnerabilities = append(result.Vulnerabilities, Vulnerability{ + Host: key.host, + Port: key.port, + Title: consolidated, + Severity: sev, + Source: "nikto", + Discovery: timestamp, + }) + } +} - // Use the highest severity from the group. - sev := items[0].severity - for _, it := range items[1:] { - if severityRank(it.severity) > severityRank(sev) { - sev = it.severity - } - } +// consolidateNiktoGroup builds a consolidated title from a group of related items. +// It extracts the common prefix and lists variable suffixes. +func consolidateNiktoGroup(items []niktoItemGroup) string { + prefix := commonPrefix(items[0].title, items[1].title) + for _, it := range items[2:] { + prefix = commonPrefix(prefix, it.title) + } + prefix = strings.TrimRight(prefix, ": ") - result.Vulnerabilities = append(result.Vulnerabilities, Vulnerability{ - Host: key.host, - Port: key.port, - Title: consolidated, - Severity: sev, - Source: "nikto", - Discovery: result.Timestamp, - }) + var suffixes []string + for _, it := range items { + suffix := strings.TrimPrefix(it.title, prefix) + suffix = strings.TrimLeft(suffix, ": ") + if suffix != "" { + suffixes = append(suffixes, suffix) } } - return result, nil + if len(suffixes) > 0 { + return fmt.Sprintf("%s (%d): %s", prefix, len(items), strings.Join(suffixes, ", ")) + } + return fmt.Sprintf("%s (%d instances)", prefix, len(items)) } // niktoSeverity assigns a heuristic severity to a nikto finding based on the @@ -1177,7 +1205,7 @@ func niktoSeverity(item niktoItem) string { "remote code", "rce", "injection"} for _, kw := range highKW { if strings.Contains(text, kw) { - return "high" + return severityHigh } } @@ -1185,7 +1213,7 @@ func niktoSeverity(item niktoItem) string { "disabled", "enabled", "x-powered", "server:", "cookie"} for _, kw := range mediumKW { if strings.Contains(text, kw) { - return "medium" + return severityMedium } } @@ -1208,11 +1236,11 @@ func commonPrefix(a, b string) string { // severityRank returns a numeric rank for comparing severities. func severityRank(sev string) int { switch strings.ToLower(sev) { - case "critical": + case severityCritical: return 4 - case "high": + case severityHigh: return 3 - case "medium": + case severityMedium: return 2 case "low": return 1 @@ -1258,122 +1286,139 @@ func (rp *ResultParser) parseSSLScanResult(result *ScanResult, content string) ( continue } - lineLower := strings.ToLower(line) - - // Detect SSLv2/SSLv3 — critical - if (strings.Contains(lineLower, "sslv2") || strings.Contains(lineLower, "sslv3")) && - strings.Contains(lineLower, "enabled") { - proto := "SSLv3" - if strings.Contains(lineLower, "sslv2") { - proto = "SSLv2" + // Detect port from sslscan output + if strings.Contains(line, "SSL/TLS") && strings.Contains(line, ":") { + portRegex := regexp.MustCompile(`(\d+)`) + if matches := portRegex.FindStringSubmatch(line); len(matches) > 1 { + currentPort, _ = strconv.Atoi(matches[1]) } - result.Vulnerabilities = append(result.Vulnerabilities, Vulnerability{ - Host: currentIP, - Port: currentPort, - Title: fmt.Sprintf("%s enabled", proto), - Description: line, - Severity: "critical", - Source: "sslscan", - Discovery: result.Timestamp, - }) } - // Detect TLS 1.0/1.1 — medium - if (strings.Contains(lineLower, "tlsv1.0") || strings.Contains(lineLower, "tlsv1.1") || - strings.Contains(lineLower, "tls 1.0") || strings.Contains(lineLower, "tls 1.1")) && - strings.Contains(lineLower, "enabled") { + rp.parseSSLScanProtocolLine(result, currentIP, currentPort, line) + rp.parseSSLScanCipherLine(result, currentIP, currentPort, line, seenDH) + rp.parseSSLScanCertLine(result, currentIP, currentPort, line) + } + + return result, nil +} + +// parseSSLScanProtocolLine detects SSLv2/SSLv3 and deprecated TLS protocol lines. +func (rp *ResultParser) parseSSLScanProtocolLine(result *ScanResult, ip string, port int, line string) { + lineLower := strings.ToLower(line) + + // Detect SSLv2/SSLv3 — critical + if (strings.Contains(lineLower, "sslv2") || strings.Contains(lineLower, "sslv3")) && + strings.Contains(lineLower, "enabled") { + proto := "SSLv3" + if strings.Contains(lineLower, "sslv2") { + proto = "SSLv2" + } + result.Vulnerabilities = append(result.Vulnerabilities, Vulnerability{ + Host: ip, + Port: port, + Title: fmt.Sprintf("%s enabled", proto), + Description: line, + Severity: severityCritical, + Source: "sslscan", + Discovery: result.Timestamp, + }) + } + + // Detect TLS 1.0/1.1 — medium + if (strings.Contains(lineLower, "tlsv1.0") || strings.Contains(lineLower, "tlsv1.1") || + strings.Contains(lineLower, "tls 1.0") || strings.Contains(lineLower, "tls 1.1")) && + strings.Contains(lineLower, "enabled") { + result.Vulnerabilities = append(result.Vulnerabilities, Vulnerability{ + Host: ip, + Port: port, + Title: "Deprecated TLS version enabled", + Description: line, + Severity: severityMedium, + Source: "sslscan", + Discovery: result.Timestamp, + }) + } +} + +// parseSSLScanCipherLine detects weak ciphers and weak DH key-exchange groups. +func (rp *ResultParser) parseSSLScanCipherLine(result *ScanResult, ip string, port int, line string, seenDH map[string]bool) { + lineLower := strings.ToLower(line) + + // Detect weak ciphers — high + weakCiphers := []string{"des", "rc4", "export", "null"} + for _, weak := range weakCiphers { + if strings.Contains(lineLower, weak) && !strings.Contains(lineLower, "not ") && !strings.Contains(lineLower, "disabled") { result.Vulnerabilities = append(result.Vulnerabilities, Vulnerability{ - Host: currentIP, - Port: currentPort, - Title: "Deprecated TLS version enabled", + Host: ip, + Port: port, + Title: fmt.Sprintf("Weak cipher: %s", strings.ToUpper(weak)), Description: line, - Severity: "medium", + Severity: severityHigh, Source: "sslscan", Discovery: result.Timestamp, }) + break } + } - // Detect weak ciphers — high - weakCiphers := []string{"des", "rc4", "export", "null"} - for _, weak := range weakCiphers { - if strings.Contains(lineLower, weak) && !strings.Contains(lineLower, "not ") && !strings.Contains(lineLower, "disabled") { - result.Vulnerabilities = append(result.Vulnerabilities, Vulnerability{ - Host: currentIP, - Port: currentPort, - Title: fmt.Sprintf("Weak cipher: %s", strings.ToUpper(weak)), - Description: line, - Severity: "high", - Source: "sslscan", - Discovery: result.Timestamp, - }) - break - } - } - - // Detect weak DH groups (1024 bits or less) — deduplicate per host. - if currentIP != "" && strings.Contains(line, "DHE") && strings.Contains(line, "bits") { - bitRegex := regexp.MustCompile(`DHE\s+(\d+)\s+bits`) - if matches := bitRegex.FindStringSubmatch(line); len(matches) > 1 { - bits, _ := strconv.Atoi(matches[1]) - if bits <= 1024 { - key := fmt.Sprintf("%s:%d", currentIP, bits) - if !seenDH[key] { - seenDH[key] = true - result.Vulnerabilities = append(result.Vulnerabilities, Vulnerability{ - Host: currentIP, - Port: currentPort, - Title: fmt.Sprintf("Weak DH key exchange group (%d bits)", bits), - Description: line, - Severity: "medium", - Source: "sslscan", - Discovery: result.Timestamp, - }) - } + // Detect weak DH groups (1024 bits or less) — deduplicate per host. + if ip != "" && strings.Contains(line, "DHE") && strings.Contains(line, "bits") { + bitRegex := regexp.MustCompile(`DHE\s+(\d+)\s+bits`) + if matches := bitRegex.FindStringSubmatch(line); len(matches) > 1 { + bits, _ := strconv.Atoi(matches[1]) + if bits <= 1024 { + key := fmt.Sprintf("%s:%d", ip, bits) + if !seenDH[key] { + seenDH[key] = true + result.Vulnerabilities = append(result.Vulnerabilities, Vulnerability{ + Host: ip, + Port: port, + Title: fmt.Sprintf("Weak DH key exchange group (%d bits)", bits), + Description: line, + Severity: severityMedium, + Source: "sslscan", + Discovery: result.Timestamp, + }) } } } + } +} - // Detect certificate issues — medium - certIssues := []string{"self-signed", "expired", "weak key"} - for _, issue := range certIssues { - if strings.Contains(lineLower, issue) { - result.Vulnerabilities = append(result.Vulnerabilities, Vulnerability{ - Host: currentIP, - Port: currentPort, - Title: fmt.Sprintf("Certificate issue: %s", issue), - Description: line, - Severity: "medium", - Source: "sslscan", - Discovery: result.Timestamp, - }) - break - } - } - - // Detect self-signed certificate from issuer line - if currentIP != "" && strings.Contains(lineLower, "issuer") && - strings.Contains(lineLower, "self-signed") { +// parseSSLScanCertLine detects certificate issues from sslscan text output. +func (rp *ResultParser) parseSSLScanCertLine(result *ScanResult, ip string, port int, line string) { + lineLower := strings.ToLower(line) + + // Detect certificate issues — medium + certIssues := []string{"self-signed", "expired", "weak key"} + for _, issue := range certIssues { + if strings.Contains(lineLower, issue) { result.Vulnerabilities = append(result.Vulnerabilities, Vulnerability{ - Host: currentIP, - Port: currentPort, - Title: "Self-signed SSL certificate", + Host: ip, + Port: port, + Title: fmt.Sprintf("Certificate issue: %s", issue), Description: line, - Severity: "medium", + Severity: severityMedium, Source: "sslscan", Discovery: result.Timestamp, }) - } - - // Detect port from sslscan output - if strings.Contains(line, "SSL/TLS") && strings.Contains(line, ":") { - portRegex := regexp.MustCompile(`(\d+)`) - if matches := portRegex.FindStringSubmatch(line); len(matches) > 1 { - currentPort, _ = strconv.Atoi(matches[1]) - } + break } } - return result, nil + // Detect self-signed certificate from issuer line + if ip != "" && strings.Contains(lineLower, "issuer") && + strings.Contains(lineLower, "self-signed") { + result.Vulnerabilities = append(result.Vulnerabilities, Vulnerability{ + Host: ip, + Port: port, + Title: "Self-signed SSL certificate", + Description: line, + Severity: severityMedium, + Source: "sslscan", + Discovery: result.Timestamp, + }) + } } // --- sslscan XML types --- @@ -1463,97 +1508,111 @@ func (rp *ResultParser) parseSSLScanXML(result *ScanResult, content string) (*Sc port := test.Port - // SSLv2/SSLv3 — critical - for _, proto := range test.Protocols { - if proto.Enabled != "1" { - continue - } - if proto.Type == "ssl" && (proto.Version == "2" || proto.Version == "3") { - name := "SSLv3" - if proto.Version == "2" { - name = "SSLv2" - } - result.Vulnerabilities = append(result.Vulnerabilities, Vulnerability{ - Host: ip, - Port: port, - Title: fmt.Sprintf("%s enabled", name), - Severity: "critical", - Source: "sslscan", - Discovery: result.Timestamp, - }) - } - // TLS 1.0/1.1 — medium - if proto.Type == "tls" && (proto.Version == "1.0" || proto.Version == "1.1") { - result.Vulnerabilities = append(result.Vulnerabilities, Vulnerability{ - Host: ip, - Port: port, - Title: "Deprecated TLS version enabled", - Description: fmt.Sprintf("TLS %s is enabled", proto.Version), - Severity: "medium", - Source: "sslscan", - Discovery: result.Timestamp, - }) - } - } + // Protocol-based vulnerabilities (SSLv2/v3, deprecated TLS) + rp.detectSSLProtocolVulns(result, ip, port, test.Protocols) - // Heartbleed - for _, hb := range test.Heartbleed { - if hb.Vulnerable == "1" { - result.Vulnerabilities = append(result.Vulnerabilities, Vulnerability{ - Host: ip, - Port: port, - Title: "OpenSSL Heartbleed", - Severity: "high", - Source: "sslscan", - Discovery: result.Timestamp, - }) - break // only report once per host + // Heartbleed + weak DH group detection + rp.detectSSLCipherVulns(result, ip, port, test.Heartbleed, test.Ciphers) + + // Certificate issues + rp.detectSSLCertVulns(result, ip, port, test.Certs) + } + + return result, nil +} + +// detectSSLProtocolVulns reports vulnerabilities for enabled SSLv2/v3 and deprecated TLS versions. +func (rp *ResultParser) detectSSLProtocolVulns(result *ScanResult, ip string, port int, protocols []sslscanProtocol) { + for _, proto := range protocols { + if proto.Enabled != "1" { + continue + } + if proto.Type == "ssl" && (proto.Version == "2" || proto.Version == "3") { + name := "SSLv3" + if proto.Version == "2" { + name = "SSLv2" } + result.Vulnerabilities = append(result.Vulnerabilities, Vulnerability{ + Host: ip, + Port: port, + Title: fmt.Sprintf("%s enabled", name), + Severity: severityCritical, + Source: "sslscan", + Discovery: result.Timestamp, + }) + } + if proto.Type == "tls" && (proto.Version == "1.0" || proto.Version == "1.1") { + result.Vulnerabilities = append(result.Vulnerabilities, Vulnerability{ + Host: ip, + Port: port, + Title: "Deprecated TLS version enabled", + Description: fmt.Sprintf("TLS %s is enabled", proto.Version), + Severity: severityMedium, + Source: "sslscan", + Discovery: result.Timestamp, + }) } + } +} - // Weak DH groups (deduplicate by bit size) - seenDHE := make(map[int]bool) - for _, cipher := range test.Ciphers { - if cipher.DHEBits > 0 && cipher.DHEBits <= 1024 && !seenDHE[cipher.DHEBits] { - seenDHE[cipher.DHEBits] = true - result.Vulnerabilities = append(result.Vulnerabilities, Vulnerability{ - Host: ip, - Port: port, - Title: fmt.Sprintf("Weak DH key exchange group (%d bits)", cipher.DHEBits), - Severity: "medium", - Source: "sslscan", - Discovery: result.Timestamp, - }) - } +// detectSSLCipherVulns reports Heartbleed and weak DH key-exchange vulnerabilities. +func (rp *ResultParser) detectSSLCipherVulns(result *ScanResult, ip string, port int, heartbleeds []sslscanHeartbleed, ciphers []sslscanCipher) { + for _, hb := range heartbleeds { + if hb.Vulnerable == "1" { + result.Vulnerabilities = append(result.Vulnerabilities, Vulnerability{ + Host: ip, + Port: port, + Title: "OpenSSL Heartbleed", + Severity: severityHigh, + Source: "sslscan", + Discovery: result.Timestamp, + }) + break } + } - // Certificate issues - if test.Certs != nil { - cert := test.Certs.Certificate - if cert.SelfSigned == "true" { - result.Vulnerabilities = append(result.Vulnerabilities, Vulnerability{ - Host: ip, - Port: port, - Title: "Self-signed SSL certificate", - Severity: "medium", - Source: "sslscan", - Discovery: result.Timestamp, - }) - } - if cert.Expired == "true" { - result.Vulnerabilities = append(result.Vulnerabilities, Vulnerability{ - Host: ip, - Port: port, - Title: "Expired SSL certificate", - Severity: "high", - Source: "sslscan", - Discovery: result.Timestamp, - }) - } + seenDHE := make(map[int]bool) + for _, cipher := range ciphers { + if cipher.DHEBits > 0 && cipher.DHEBits <= 1024 && !seenDHE[cipher.DHEBits] { + seenDHE[cipher.DHEBits] = true + result.Vulnerabilities = append(result.Vulnerabilities, Vulnerability{ + Host: ip, + Port: port, + Title: fmt.Sprintf("Weak DH key exchange group (%d bits)", cipher.DHEBits), + Severity: severityMedium, + Source: "sslscan", + Discovery: result.Timestamp, + }) } } +} - return result, nil +// detectSSLCertVulns reports self-signed and expired certificate vulnerabilities. +func (rp *ResultParser) detectSSLCertVulns(result *ScanResult, ip string, port int, certs *sslscanCerts) { + if certs == nil { + return + } + cert := certs.Certificate + if cert.SelfSigned == "true" { + result.Vulnerabilities = append(result.Vulnerabilities, Vulnerability{ + Host: ip, + Port: port, + Title: "Self-signed SSL certificate", + Severity: severityMedium, + Source: "sslscan", + Discovery: result.Timestamp, + }) + } + if cert.Expired == "true" { + result.Vulnerabilities = append(result.Vulnerabilities, Vulnerability{ + Host: ip, + Port: port, + Title: "Expired SSL certificate", + Severity: severityHigh, + Source: "sslscan", + Discovery: result.Timestamp, + }) + } } // --- LLDP/CDP XML types --- @@ -1826,10 +1885,10 @@ func (rp *ResultParser) parseExploitSearchResult(result *ScanResult, content str for _, svc := range h.Services { for _, expl := range svc.Exploits { - exploitSeverity := "high" + exploitSeverity := severityHigh switch strings.ToLower(expl.Type) { case "dos", "local": - exploitSeverity = "medium" + exploitSeverity = severityMedium } vuln := Vulnerability{ Host: h.IP, diff --git a/internal/correlation/parsers_test.go b/internal/correlation/parsers_test.go index 84e3ba4..1cd0f16 100644 --- a/internal/correlation/parsers_test.go +++ b/internal/correlation/parsers_test.go @@ -379,7 +379,7 @@ PORT STATE SERVICE 80/tcp open http ` - if err := os.WriteFile(resultFile, []byte(content), 0644); err != nil { + if err := os.WriteFile(resultFile, []byte(content), 0600); err != nil { t.Fatalf("Failed to create test file: %v", err) } @@ -415,7 +415,7 @@ func TestScanWorkspaceForResults(t *testing.T) { for filename, content := range files { filePath := filepath.Join(tempDir, filename) - if err := os.WriteFile(filePath, []byte(content), 0644); err != nil { + if err := os.WriteFile(filePath, []byte(content), 0600); err != nil { t.Fatalf("Failed to create test file: %v", err) } } @@ -701,7 +701,7 @@ func TestScanWorkspaceForResults_SubdirectoryScan(t *testing.T) { dir := t.TempDir() // Two levels deep — filepath.Glob("**") would not recurse this far subdir := filepath.Join(dir, "scan-2024-01-15", "subscans") - if err := os.MkdirAll(subdir, 0755); err != nil { + if err := os.MkdirAll(subdir, 0750); err != nil { t.Fatalf("MkdirAll: %v", err) } nmapContent := `Starting Nmap 7.94 @@ -711,7 +711,7 @@ PORT STATE SERVICE 22/tcp open ssh Nmap done: 1 IP address (1 host up) scanned` nmapFile := filepath.Join(subdir, "result.nmap") - if err := os.WriteFile(nmapFile, []byte(nmapContent), 0644); err != nil { + if err := os.WriteFile(nmapFile, []byte(nmapContent), 0600); err != nil { t.Fatalf("WriteFile: %v", err) } rp := NewResultParser(dir) @@ -730,7 +730,7 @@ func TestScanWorkspaceSkipsLargeFiles(t *testing.T) { // Create a small file smallFile := filepath.Join(tempDir, "small.txt") - if err := os.WriteFile(smallFile, []byte("192.168.1.1"), 0644); err != nil { + if err := os.WriteFile(smallFile, []byte("192.168.1.1"), 0600); err != nil { t.Fatalf("Failed to create small file: %v", err) } @@ -842,6 +842,36 @@ func TestParseNiktoXMLResultInvalid(t *testing.T) { _ = parsed } +func findVuln(t *testing.T, vulns []Vulnerability, match func(Vulnerability) bool, msg string) { + t.Helper() + for _, v := range vulns { + if match(v) { + return + } + } + t.Errorf("expected vulnerability: %s", msg) +} + +func assertNoVulnForHost(t *testing.T, vulns []Vulnerability, host, msg string) { + t.Helper() + for _, v := range vulns { + if v.Host == host { + t.Error(msg) + return + } + } +} + +func assertAllVulnSource(t *testing.T, vulns []Vulnerability, want string) { + t.Helper() + for _, v := range vulns { + if v.Source != want { + t.Errorf("Source = %q, want %q", v.Source, want) + return + } + } +} + func TestParseSSLScanResult(t *testing.T) { parser := NewResultParser("") @@ -881,70 +911,26 @@ SSL/TLS: t.Fatalf("parseSSLScanResult() error = %v", err) } - // Should find 2 hosts if len(parsed.Hosts) < 1 { - t.Error("Should parse at least 1 host") - } - - // Should detect SSLv3 as critical - var foundSSLv3 bool - for _, v := range parsed.Vulnerabilities { - if v.Host == "192.168.1.1" && v.Severity == "critical" && strings.Contains(v.Title, "SSLv3") { - foundSSLv3 = true - } - } - if !foundSSLv3 { - t.Error("Should detect SSLv3 as critical") - } - - // Should detect TLS 1.0 as medium - var foundTLS10 bool - for _, v := range parsed.Vulnerabilities { - if v.Host == "192.168.1.1" && v.Severity == "medium" && strings.Contains(v.Title, "Deprecated") { - foundTLS10 = true - } - } - if !foundTLS10 { - t.Error("Should detect TLS 1.0 as deprecated") - } - - // Should detect weak ciphers - var foundWeak bool - for _, v := range parsed.Vulnerabilities { - if v.Host == "192.168.1.1" && v.Severity == "high" && v.Source == "sslscan" { - foundWeak = true - } - } - if !foundWeak { - t.Error("Should detect weak cipher (DES/RC4)") - } - - // Should detect self-signed cert - var foundSelfSigned bool - for _, v := range parsed.Vulnerabilities { - if v.Host == "192.168.1.1" && strings.Contains(strings.ToLower(v.Title), "self-signed") { - foundSelfSigned = true - } - } - if !foundSelfSigned { - t.Error("Should detect self-signed certificate") - } - - // Check Source field - for _, v := range parsed.Vulnerabilities { - if v.Source != "sslscan" { - t.Errorf("Source = %q, want 'sslscan'", v.Source) - break - } - } - - // Second host should be clean - for _, v := range parsed.Vulnerabilities { - if v.Host == "192.168.1.2" { - t.Error("Second host should have no findings (only TLS 1.2/1.3)") - break - } - } + t.Fatal("Should parse at least 1 host") + } + + vulns := parsed.Vulnerabilities + findVuln(t, vulns, func(v Vulnerability) bool { + return v.Host == "192.168.1.1" && v.Severity == "critical" && strings.Contains(v.Title, "SSLv3") + }, "Should detect SSLv3 as critical") + findVuln(t, vulns, func(v Vulnerability) bool { + return v.Host == "192.168.1.1" && v.Severity == "medium" && strings.Contains(v.Title, "Deprecated") + }, "Should detect TLS 1.0 as deprecated") + findVuln(t, vulns, func(v Vulnerability) bool { + return v.Host == "192.168.1.1" && v.Severity == "high" && v.Source == "sslscan" + }, "Should detect weak cipher (DES/RC4)") + findVuln(t, vulns, func(v Vulnerability) bool { + return v.Host == "192.168.1.1" && strings.Contains(strings.ToLower(v.Title), "self-signed") + }, "Should detect self-signed certificate") + assertAllVulnSource(t, vulns, "sslscan") + assertNoVulnForHost(t, vulns, "192.168.1.2", + "Second host should have no findings (only TLS 1.2/1.3)") } func TestDetermineScanTypeNikto(t *testing.T) { @@ -972,7 +958,7 @@ func TestScanWorkspaceForResults_SSLScan(t *testing.T) { // Create sslscan XML file in a supplementary directory suppDir := filepath.Join(tempDir, "supplementary") - if err := os.MkdirAll(suppDir, 0755); err != nil { + if err := os.MkdirAll(suppDir, 0750); err != nil { t.Fatal(err) } sslContent := ` @@ -981,7 +967,7 @@ func TestScanWorkspaceForResults_SSLScan(t *testing.T) { ` - if err := os.WriteFile(filepath.Join(suppDir, "sslscan_10.0.0.1.xml"), []byte(sslContent), 0644); err != nil { + if err := os.WriteFile(filepath.Join(suppDir, "sslscan_10.0.0.1.xml"), []byte(sslContent), 0600); err != nil { t.Fatal(err) } diff --git a/internal/correlation/screenshot.go b/internal/correlation/screenshot.go index c3801be..c0e2ee0 100644 --- a/internal/correlation/screenshot.go +++ b/internal/correlation/screenshot.go @@ -169,11 +169,11 @@ func FindScreenshotsOnDisk(workspaceDir string) map[string][]ScreenshotInfo { func parseScreenshotJSONL(jsonlPath string) []ScreenshotInfo { screenshots := make([]ScreenshotInfo, 0) - file, err := os.Open(jsonlPath) + file, err := os.Open(jsonlPath) //nolint:gosec // G304: path from trusted workspace if err != nil { return screenshots } - defer file.Close() + defer func() { _ = file.Close() }() scanner := bufio.NewScanner(file) scanner.Buffer(make([]byte, 0, 64*1024), 10*1024*1024) // up to 10 MB per line diff --git a/internal/correlation/screenshot_test.go b/internal/correlation/screenshot_test.go index 6f8cd14..0e92be0 100644 --- a/internal/correlation/screenshot_test.go +++ b/internal/correlation/screenshot_test.go @@ -244,7 +244,7 @@ func TestFindScreenshotsOnDisk(t *testing.T) { t.Run("workspace with screenshot JSONL file", func(t *testing.T) { tmpDir := t.TempDir() screenshotsDir := filepath.Join(tmpDir, "captures", "screenshots", "20250120_120000") - if err := os.MkdirAll(screenshotsDir, 0755); err != nil { + if err := os.MkdirAll(screenshotsDir, 0750); err != nil { t.Fatalf("Failed to create screenshots directory: %v", err) } @@ -253,7 +253,7 @@ func TestFindScreenshotsOnDisk(t *testing.T) { jsonlContent := `{"url":"http://192.168.1.1","file_name":"http--192.168.1.1-80.jpeg","screenshot":"","response_code":200,"failed":false} {"url":"https://192.168.1.1","file_name":"https--192.168.1.1-443.jpeg","screenshot":"","response_code":200,"failed":false} {"url":"http://192.168.1.2","file_name":"http--192.168.1.2-80.jpeg","screenshot":"","response_code":404,"failed":false}` - if err := os.WriteFile(jsonlPath, []byte(jsonlContent), 0644); err != nil { + if err := os.WriteFile(jsonlPath, []byte(jsonlContent), 0600); err != nil { t.Fatalf("Failed to write JSONL file: %v", err) } @@ -393,7 +393,7 @@ func TestParseScreenshotJSONL(t *testing.T) { {"url":"https://192.168.1.1","final_url":"","file_name":"https--192.168.1.1-443.jpeg","screenshot":"","response_code":200,"failed":false} {"url":"http://192.168.1.2","final_url":"http://192.168.1.2/dashboard","file_name":"http--192.168.1.2-80.jpeg","screenshot":"","response_code":404,"failed":false}` - if err := os.WriteFile(jsonlPath, []byte(jsonlContent), 0644); err != nil { + if err := os.WriteFile(jsonlPath, []byte(jsonlContent), 0600); err != nil { t.Fatalf("Failed to write JSONL file: %v", err) } @@ -429,7 +429,7 @@ func TestParseScreenshotJSONL(t *testing.T) { invalid json line {"url":"https://192.168.1.1","file_name":"https--192.168.1.1-443.jpeg","screenshot":"","response_code":200,"failed":false}` - if err := os.WriteFile(jsonlPath, []byte(jsonlContent), 0644); err != nil { + if err := os.WriteFile(jsonlPath, []byte(jsonlContent), 0600); err != nil { t.Fatalf("Failed to write JSONL file: %v", err) } @@ -449,7 +449,7 @@ invalid json line {"url":"http://192.168.1.2","file_name":"","screenshot":"","response_code":0,"failed":true,"failed_reason":"connection refused"} {"url":"https://192.168.1.1","file_name":"https--192.168.1.1-443.jpeg","screenshot":"","response_code":200,"failed":false}` - if err := os.WriteFile(jsonlPath, []byte(jsonlContent), 0644); err != nil { + if err := os.WriteFile(jsonlPath, []byte(jsonlContent), 0600); err != nil { t.Fatalf("Failed to write JSONL file: %v", err) } @@ -475,7 +475,7 @@ invalid json line jsonlContent := `{"url":"http://192.168.1.1","final_url":"http://router.local/home","file_name":"http--192.168.1.1-80.jpeg","screenshot":"","response_code":200,"failed":false} {"url":"https://192.168.1.2","final_url":"https://device.local/","file_name":"https--192.168.1.2-443.jpeg","screenshot":"","response_code":301,"failed":false}` - if err := os.WriteFile(jsonlPath, []byte(jsonlContent), 0644); err != nil { + if err := os.WriteFile(jsonlPath, []byte(jsonlContent), 0600); err != nil { t.Fatalf("Failed to write JSONL file: %v", err) } diff --git a/internal/correlation/topology.go b/internal/correlation/topology.go index 31628d2..93fcec8 100644 --- a/internal/correlation/topology.go +++ b/internal/correlation/topology.go @@ -12,6 +12,11 @@ import ( "time" ) +const ( + vlanDefault = "default" + hostTypeRouter = "router" +) + // TopologyGenerator creates network topology visualizations from correlated data. type TopologyGenerator struct { workspaceDir string @@ -138,13 +143,13 @@ func (tg *TopologyGenerator) GenerateHTMLViewer(correlations map[string]*Correla html := topologyHTMLHead + topologyD3 + "\nconst VLANS = " + string(vlansData) + ";\nconst CONNECTIONS = " + string(connData) + ";\n" + topologyHTMLJS + "\n\n\n" outDir := filepath.Join(tg.workspaceDir, "topology") - if err := os.MkdirAll(outDir, 0o755); err != nil { + if err := os.MkdirAll(outDir, 0750); err != nil { return "", fmt.Errorf("creating topology dir: %w", err) } ts := time.Now().Format("20060102_150405") htmlPath := filepath.Join(outDir, "topology_viewer_"+ts+".html") - if err := os.WriteFile(htmlPath, []byte(html), 0o644); err != nil { + if err := os.WriteFile(htmlPath, []byte(html), 0o600); err != nil { return "", fmt.Errorf("writing HTML viewer: %w", err) } @@ -203,7 +208,7 @@ func vlanForHost(corr *CorrelationResult) string { if corr.HostInfo != nil && corr.HostInfo.IP != "" { return subnetFromIP(corr.HostInfo.IP) } - return "default" + return vlanDefault } func subnetFromIP(ip string) string { @@ -213,13 +218,13 @@ func subnetFromIP(ip string) string { if len(parts) >= 3 { return parts[0] + ":" + parts[1] + ":" + parts[2] + "::/48" } - return "default" + return vlanDefault } parts := strings.Split(ip, ".") if len(parts) >= 3 { return parts[0] + "." + parts[1] + "." + parts[2] + ".0/24" } - return "default" + return vlanDefault } func safeHost(corr *CorrelationResult) string { @@ -237,7 +242,7 @@ func vlanGroupName(vlan string, hosts []*CorrelationResult) string { } } } - if vlan != "default" { + if vlan != vlanDefault { return "VLAN " + vlan } return "Default Network" @@ -336,12 +341,12 @@ func isGateway(corr *CorrelationResult) bool { return false } dt := strings.ToLower(corr.HostInfo.Attributes["device_type"]) - if dt == "router" || dt == "gateway" || dt == "firewall" || dt == "layer3" || dt == "switch_l3" { + if dt == hostTypeRouter || dt == "gateway" || dt == "firewall" || dt == "layer3" || dt == "switch_l3" { return true } // Check capabilities attribute caps := strings.ToLower(corr.HostInfo.Attributes["capabilities"]) - if strings.Contains(caps, "router") || strings.Contains(caps, "gateway") { + if strings.Contains(caps, hostTypeRouter) || strings.Contains(caps, "gateway") { return true } return false diff --git a/internal/correlation/topology_test.go b/internal/correlation/topology_test.go index b10145c..04297a9 100644 --- a/internal/correlation/topology_test.go +++ b/internal/correlation/topology_test.go @@ -58,7 +58,7 @@ func TestGenerateHTMLViewer_BasicOutput(t *testing.T) { t.Errorf("expected .html suffix, got %q", htmlPath) } - data, err := os.ReadFile(htmlPath) + data, err := os.ReadFile(htmlPath) //nolint:gosec // G304: test path if err != nil { t.Fatalf("reading HTML file: %v", err) } @@ -249,7 +249,7 @@ func TestGenerateHTMLViewer_IsLocal(t *testing.T) { t.Fatalf("GenerateHTMLViewer failed: %v", err) } - data, err := os.ReadFile(htmlPath) + data, err := os.ReadFile(htmlPath) //nolint:gosec // G304: test path if err != nil { t.Fatalf("reading HTML: %v", err) } @@ -283,7 +283,7 @@ func TestGenerateHTMLViewer_MultiVLANLocalRemote(t *testing.T) { t.Fatalf("GenerateHTMLViewer failed: %v", err) } - data, err := os.ReadFile(htmlPath) + data, err := os.ReadFile(htmlPath) //nolint:gosec // G304: test path if err != nil { t.Fatalf("reading HTML: %v", err) } @@ -329,7 +329,7 @@ func TestGenerateHTMLViewer_RiskAndVulnData(t *testing.T) { t.Fatalf("GenerateHTMLViewer failed: %v", err) } - data, err := os.ReadFile(htmlPath) + data, err := os.ReadFile(htmlPath) //nolint:gosec // G304: test path if err != nil { t.Fatalf("reading HTML: %v", err) } diff --git a/internal/executor/executor.go b/internal/executor/executor.go index 4f5fea1..653c3bb 100644 --- a/internal/executor/executor.go +++ b/internal/executor/executor.go @@ -80,14 +80,14 @@ func (e *Executor) ExecuteScript(scriptPath string, outputWriter io.Writer) (*Sc case <-e.cancel: // Command was cancelled if cmd.Process != nil { - cmd.Process.Kill() + _ = cmd.Process.Kill() } cmdErr = <-done // Wait for process to actually exit } // Close stdin if e.stdin != nil { - e.stdin.Close() + _ = e.stdin.Close() e.stdin = nil } diff --git a/internal/executor/streaming.go b/internal/executor/streaming.go index 3e8691a..3797b22 100644 --- a/internal/executor/streaming.go +++ b/internal/executor/streaming.go @@ -76,15 +76,15 @@ func (r *StreamingResult) SetFinal(success bool, exitCode int, err error, endTim } // GetFinal returns the final result fields atomically. -func (r *StreamingResult) GetFinal() (success bool, exitCode int, err error, duration time.Duration, endTime time.Time) { +func (r *StreamingResult) GetFinal() (success bool, exitCode int, duration time.Duration, endTime time.Time, err error) { r.mu.Lock() defer r.mu.Unlock() - return r.Success, r.ExitCode, r.Error, r.Duration, r.EndTime + return r.Success, r.ExitCode, r.Duration, r.EndTime, r.Error } // NewStreamingExecutor creates a new streaming executor func NewStreamingExecutor() *StreamingExecutor { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(context.Background()) //nolint:gosec // G118: cancel stored in struct return &StreamingExecutor{ ctx: ctx, cancel: cancel, @@ -228,7 +228,7 @@ func (e *StreamingExecutor) executeScript(scriptPath string, result *StreamingRe // Close stdin if e.stdin != nil { - e.stdin.Close() + _ = e.stdin.Close() e.stdin = nil } } diff --git a/internal/executor/streaming_test.go b/internal/executor/streaming_test.go index 4e490a0..509b6b1 100644 --- a/internal/executor/streaming_test.go +++ b/internal/executor/streaming_test.go @@ -14,7 +14,7 @@ func writeScript(t *testing.T, content string) string { t.Helper() dir := t.TempDir() path := filepath.Join(dir, "test.sh") - if err := os.WriteFile(path, []byte("#!/bin/sh\n"+content), 0755); err != nil { + if err := os.WriteFile(path, []byte("#!/bin/sh\n"+content), 0750); err != nil { //nolint:gosec // test script needs execute t.Fatalf("writeScript: %v", err) } return path @@ -93,7 +93,7 @@ func TestExecuteScriptStreaming_Cancellation(t *testing.T) { t.Fatal("timed out waiting for output") } } - e.Stop() + _ = e.Stop() e.Wait() if e.IsRunning() { t.Error("executor should not be running after Stop()") @@ -108,7 +108,7 @@ func TestExecuteScriptStreaming_AlreadyRunning(t *testing.T) { e := NewStreamingExecutor() _, outCh, _ := e.ExecuteScriptStreaming(script) defer func() { - e.Stop() + _ = e.Stop() e.Wait() for range outCh { } @@ -155,7 +155,7 @@ func TestSetFinalGetFinal_Concurrent(t *testing.T) { case <-deadline: t.Fatal("timed out waiting for SetFinal") default: - result.GetFinal() + _, _, _, _, _ = result.GetFinal() } } } diff --git a/internal/jobs/manager.go b/internal/jobs/manager.go index a1ca031..3c30151 100644 --- a/internal/jobs/manager.go +++ b/internal/jobs/manager.go @@ -2,6 +2,7 @@ package jobs import ( "fmt" + "os" "sort" "strconv" "strings" @@ -214,7 +215,7 @@ func (jm *JobManager) monitorJob(job *Job) { job.Duration = job.EndTime.Sub(job.StartTime) if job.Result != nil { - success, _, jobErr, _, _ := job.Result.GetFinal() + success, _, _, _, jobErr := job.Result.GetFinal() if success { job.Status = JobStatusCompleted } else { @@ -282,7 +283,9 @@ func (jm *JobManager) CancelJob(jobID string) error { // Stop the executor if executor != nil { - executor.Stop() + if err := executor.Stop(); err != nil { + fmt.Fprintf(os.Stderr, "Warning: failed to stop executor: %v\n", err) + } } // Update job status fields under the job lock @@ -552,7 +555,7 @@ func (jm *JobManager) Stop() { for _, job := range jm.jobs { job.mu.RLock() if job.Status == JobStatusRunning && job.Executor != nil { - job.Executor.Stop() + _ = job.Executor.Stop() } job.mu.RUnlock() } diff --git a/internal/metadata/script.go b/internal/metadata/script.go index 440d259..63e37e0 100644 --- a/internal/metadata/script.go +++ b/internal/metadata/script.go @@ -162,7 +162,7 @@ func (r *ScriptRegistry) LoadMetadata() error { func (r *ScriptRegistry) loadScriptMetadata(metaPath string) (ScriptMetadata, error) { var metadata ScriptMetadata - data, err := os.ReadFile(metaPath) + data, err := os.ReadFile(metaPath) //nolint:gosec // G304: metaPath from trusted scripts dir if err != nil { return metadata, fmt.Errorf("failed to read file: %w", err) } diff --git a/internal/metadata/script_test.go b/internal/metadata/script_test.go index f23dcc0..cbc40e9 100644 --- a/internal/metadata/script_test.go +++ b/internal/metadata/script_test.go @@ -57,7 +57,7 @@ func TestLoadScriptMetadataValid(t *testing.T) { last_updated: "2025-01-01" ` - if err := os.WriteFile(metaFile, []byte(yamlContent), 0644); err != nil { + if err := os.WriteFile(metaFile, []byte(yamlContent), 0600); err != nil { t.Fatalf("Failed to create test file: %v", err) } @@ -121,7 +121,7 @@ func TestLoadScriptMetadataMissingFields(t *testing.T) { tempDir := t.TempDir() metaFile := filepath.Join(tempDir, "test.meta.yaml") - if err := os.WriteFile(metaFile, []byte(tt.yamlContent), 0644); err != nil { + if err := os.WriteFile(metaFile, []byte(tt.yamlContent), 0600); err != nil { t.Fatalf("Failed to create test file: %v", err) } @@ -139,7 +139,7 @@ func TestLoadScriptMetadataInvalidYAML(t *testing.T) { tempDir := t.TempDir() metaFile := filepath.Join(tempDir, "test.meta.yaml") - if err := os.WriteFile(metaFile, []byte("{invalid yaml"), 0644); err != nil { + if err := os.WriteFile(metaFile, []byte("{invalid yaml"), 0600); err != nil { t.Fatalf("Failed to create test file: %v", err) } @@ -510,12 +510,12 @@ func TestValidateScriptFileNotFound(t *testing.T) { func TestValidateScriptFileExists(t *testing.T) { tempDir := t.TempDir() categoryDir := filepath.Join(tempDir, "discovery") - if err := os.MkdirAll(categoryDir, 0755); err != nil { + if err := os.MkdirAll(categoryDir, 0750); err != nil { t.Fatalf("Failed to create category dir: %v", err) } scriptFile := filepath.Join(categoryDir, "test.sh") - if err := os.WriteFile(scriptFile, []byte("#!/bin/sh\necho test"), 0755); err != nil { + if err := os.WriteFile(scriptFile, []byte("#!/bin/sh\necho test"), 0750); err != nil { //nolint:gosec // test script needs execute t.Fatalf("Failed to create script file: %v", err) } @@ -564,7 +564,7 @@ func TestParameterStructure(t *testing.T) { func TestValidateScript_MissingTool(t *testing.T) { dir := t.TempDir() scriptFile := filepath.Join(dir, "scan.sh") - if err := os.WriteFile(scriptFile, []byte("#!/bin/sh\necho test"), 0755); err != nil { + if err := os.WriteFile(scriptFile, []byte("#!/bin/sh\necho test"), 0750); err != nil { //nolint:gosec // test script needs execute t.Fatal(err) } registry := NewScriptRegistry(dir) @@ -591,7 +591,7 @@ func TestValidateScript_MissingTool(t *testing.T) { func TestValidateScript_PresentTool(t *testing.T) { dir := t.TempDir() scriptFile := filepath.Join(dir, "scan.sh") - if err := os.WriteFile(scriptFile, []byte("#!/bin/sh\necho test"), 0755); err != nil { + if err := os.WriteFile(scriptFile, []byte("#!/bin/sh\necho test"), 0750); err != nil { //nolint:gosec // test script needs execute t.Fatal(err) } registry := NewScriptRegistry(dir) @@ -617,7 +617,7 @@ func TestLoadMetadataIntegration(t *testing.T) { // Create directory structure discoveryDir := filepath.Join(tempDir, "discovery") - if err := os.MkdirAll(discoveryDir, 0755); err != nil { + if err := os.MkdirAll(discoveryDir, 0750); err != nil { t.Fatalf("Failed to create discovery dir: %v", err) } @@ -645,7 +645,7 @@ func TestLoadMetadataIntegration(t *testing.T) { author: "Test" last_updated: "2025-01-01" ` - if err := os.WriteFile(metaFile, []byte(yamlContent), 0644); err != nil { + if err := os.WriteFile(metaFile, []byte(yamlContent), 0600); err != nil { t.Fatalf("Failed to create meta file: %v", err) } diff --git a/internal/ui/correlation.go b/internal/ui/correlation.go index 319fe4f..2a4eeaa 100644 --- a/internal/ui/correlation.go +++ b/internal/ui/correlation.go @@ -22,6 +22,17 @@ import ( "netutil/internal/correlation" ) +const ( + catUnknown = "unknown" + catWindows = "windows" + catLinux = "linux" + catNetworkDevice = "network_device" + colorGreen = "green" + colorYellow = "yellow" + colorGray = "gray" + portStatusOpen = "open" +) + // hostCategory returns the ph7 category from HostInfo.Attributes ("windows", // "linux", "network_device", or "unknown"). Falls back to "unknown" if unset. func hostCategory(result *correlation.CorrelationResult) string { @@ -30,7 +41,7 @@ func hostCategory(result *correlation.CorrelationResult) string { return cat } } - return "unknown" + return catUnknown } // hostVendor returns the ph7 vendor string, or "-" if absent. @@ -63,7 +74,7 @@ func hostOpenPorts(result *correlation.CorrelationResult) string { } var portNums []int for _, p := range result.HostInfo.Ports { - if p.State == "open" { + if p.State == portStatusOpen { portNums = append(portNums, p.Number) } } @@ -81,11 +92,11 @@ func hostOpenPorts(result *correlation.CorrelationResult) string { // categoryOrder returns a sort key for display order: windows=0, linux=1, network_device=2, unknown=3. func categoryOrder(cat string) int { switch cat { - case "windows": + case catWindows: return 0 - case "linux": + case catLinux: return 1 - case "network_device": + case catNetworkDevice: return 2 default: return 3 @@ -95,11 +106,11 @@ func categoryOrder(cat string) int { // categoryTcellColor returns the tcell display color for a category (for table cells). func categoryTcellColor(cat string) tcell.Color { switch cat { - case "windows": + case catWindows: return tcell.ColorGreen - case "linux": + case catLinux: return tcell.ColorYellow - case "network_device": + case catNetworkDevice: return tcell.ColorBlue default: return tcell.ColorGray @@ -109,14 +120,14 @@ func categoryTcellColor(cat string) tcell.Color { // categoryTviewColor returns the tview markup color name for a category (for TextView). func categoryTviewColor(cat string) string { switch cat { - case "windows": - return "green" - case "linux": - return "yellow" - case "network_device": + case catWindows: + return colorGreen + case catLinux: + return colorYellow + case catNetworkDevice: return "aqua" default: - return "gray" + return colorGray } } @@ -160,7 +171,7 @@ func hostMatchesText(ip string, result *correlation.CorrelationResult, text stri } } for _, p := range result.HostInfo.Ports { - if p.State != "open" { + if p.State != portStatusOpen { continue } if strings.Contains(strconv.Itoa(p.Number), lower) { @@ -174,7 +185,7 @@ func hostMatchesText(ip string, result *correlation.CorrelationResult, text stri } // filterCategories defines the cycling order for the category filter. -var filterCategories = []string{"", "windows", "linux", "network_device", "unknown"} +var filterCategories = []string{"", catWindows, catLinux, catNetworkDevice, catUnknown} // cycleCategoryFilter advances filterCategory to the next value in the cycle. func (cv *CorrelationViewer) cycleCategoryFilter() { @@ -443,24 +454,10 @@ func (cv *CorrelationViewer) updateHostsList() { } } -// updateDetailsPanel renders host identity, classification, and port data for the selected host. -func (cv *CorrelationViewer) updateDetailsPanel() { - if cv.selectedHost == "" { - cv.detailsPanel.SetText(cv.str.HostDetailsSelectPrompt) - return - } - - result, exists := cv.correlator.GetCorrelationForHost(cv.selectedHost) - if !exists { - cv.detailsPanel.SetText(cv.str.HostDetailsNoData) - return - } - - var b strings.Builder - - // --- Identity --- +// writeIdentitySection renders the host identity block (IP, MAC, hostname, NetBIOS, OS). +func (cv *CorrelationViewer) writeIdentitySection(b *strings.Builder, result *correlation.CorrelationResult) { b.WriteString(cv.str.HostDetailsIdentity) - b.WriteString(fmt.Sprintf("IP: [white]%s[::-]\n", result.Host)) + fmt.Fprintf(b, "IP: [white]%s[::-]\n", result.Host) mac := "-" hostname := "-" netbios := "-" @@ -481,13 +478,15 @@ func (cv *CorrelationViewer) updateDetailsPanel() { osStr = result.HostInfo.OS } } - b.WriteString(fmt.Sprintf("MAC: [white]%s[::-]\n", mac)) - b.WriteString(fmt.Sprintf("Hostname: [white]%s[::-]\n", hostname)) - b.WriteString(fmt.Sprintf("NetBIOS: [white]%s[::-]\n", netbios)) - b.WriteString(fmt.Sprintf("OS: [white]%s[::-]\n", osStr)) + fmt.Fprintf(b, "MAC: [white]%s[::-]\n", mac) + fmt.Fprintf(b, "Hostname: [white]%s[::-]\n", hostname) + fmt.Fprintf(b, "NetBIOS: [white]%s[::-]\n", netbios) + fmt.Fprintf(b, "OS: [white]%s[::-]\n", osStr) b.WriteString("\n") +} - // --- Classification --- +// writeClassificationSection renders the host classification block (category, vendor, confidence, TTL). +func (cv *CorrelationViewer) writeClassificationSection(b *strings.Builder, result *correlation.CorrelationResult) { b.WriteString(cv.str.HostDetailsClassification) cat := hostCategory(result) vendor := hostVendor(result) @@ -501,30 +500,32 @@ func (cv *CorrelationViewer) updateDetailsPanel() { score = s } } - b.WriteString(fmt.Sprintf("Category: [%s]%s[::-]\n", categoryTviewColor(cat), cat)) - b.WriteString(fmt.Sprintf("Vendor: [white]%s[::-]\n", vendor)) + fmt.Fprintf(b, "Category: [%s]%s[::-]\n", categoryTviewColor(cat), cat) + fmt.Fprintf(b, "Vendor: [white]%s[::-]\n", vendor) if score != "-" { - b.WriteString(fmt.Sprintf("Confidence: [white]%s[::-] (score %s)\n", confidence, score)) + fmt.Fprintf(b, "Confidence: [white]%s[::-] (score %s)\n", confidence, score) } else { - b.WriteString(fmt.Sprintf("Confidence: [white]%s[::-]\n", confidence)) + fmt.Fprintf(b, "Confidence: [white]%s[::-]\n", confidence) } if result.HostInfo != nil { if ttl, ok := result.HostInfo.Attributes["ttl_normalized"]; ok && ttl != "" { - b.WriteString(fmt.Sprintf("TTL: [white]%s[::-]\n", ttl)) + fmt.Fprintf(b, "TTL: [white]%s[::-]\n", ttl) } } b.WriteString("\n") +} - // --- Ports & Services --- +// writePortsSection renders the open ports and services block. +func (cv *CorrelationViewer) writePortsSection(b *strings.Builder, result *correlation.CorrelationResult) { var openPorts []correlation.Port if result.HostInfo != nil { for _, p := range result.HostInfo.Ports { - if p.State == "open" { + if p.State == portStatusOpen { openPorts = append(openPorts, p) } } } - b.WriteString(fmt.Sprintf(cv.str.FmtHostDetailsPorts, len(openPorts))) + fmt.Fprintf(b, cv.str.FmtHostDetailsPorts, len(openPorts)) if len(openPorts) == 0 { b.WriteString(cv.str.HostDetailsNoOpenPorts) } else { @@ -537,28 +538,49 @@ func (cv *CorrelationViewer) updateDetailsPanel() { if ver == "" { ver = "-" } - b.WriteString(fmt.Sprintf("[white]%d/%s[::-] %-8s %s\n", - p.Number, p.Protocol, svc, ver)) + fmt.Fprintf(b, "[white]%d/%s[::-] %-8s %s\n", + p.Number, p.Protocol, svc, ver) } } +} - // --- Screenshots --- +// writeScreenshotsSection renders the screenshots block. +func (cv *CorrelationViewer) writeScreenshotsSection(b *strings.Builder, result *correlation.CorrelationResult) { screenshots := correlation.GetScreenshotsForHost(result) - b.WriteString(fmt.Sprintf(cv.str.FmtHostDetailsScreenshots, len(screenshots))) + fmt.Fprintf(b, cv.str.FmtHostDetailsScreenshots, len(screenshots)) if len(screenshots) == 0 { b.WriteString(cv.str.HostDetailsNoScreenshots) } else { for i, ss := range screenshots { - statusColor := "green" + statusColor := colorGreen if ss.StatusCode != "200" { - statusColor = "yellow" + statusColor = colorYellow } - b.WriteString(fmt.Sprintf("[%s]%d.[:-] [white]%s[::-] [gray](%s)[::-]\n", - statusColor, i+1, ss.URL, ss.StatusCode)) + fmt.Fprintf(b, "[%s]%d.[:-] [white]%s[::-] [gray](%s)[::-]\n", + statusColor, i+1, ss.URL, ss.StatusCode) } b.WriteString(cv.str.HostDetailsPressS) } +} + +// updateDetailsPanel renders host identity, classification, and port data for the selected host. +func (cv *CorrelationViewer) updateDetailsPanel() { + if cv.selectedHost == "" { + cv.detailsPanel.SetText(cv.str.HostDetailsSelectPrompt) + return + } + + result, exists := cv.correlator.GetCorrelationForHost(cv.selectedHost) + if !exists { + cv.detailsPanel.SetText(cv.str.HostDetailsNoData) + return + } + var b strings.Builder + cv.writeIdentitySection(&b, result) + cv.writeClassificationSection(&b, result) + cv.writePortsSection(&b, result) + cv.writeScreenshotsSection(&b, result) cv.detailsPanel.SetText(b.String()) } @@ -643,9 +665,9 @@ func (cv *CorrelationViewer) openCategorizationModal() { } list := tview.NewList(). - AddItem(cv.str.CatModalWindows, "", '1', func() { applyCategory("windows") }). - AddItem(cv.str.CatModalLinux, "", '2', func() { applyCategory("linux") }). - AddItem(cv.str.CatModalNetDevice, "", '3', func() { applyCategory("network_device") }). + AddItem(cv.str.CatModalWindows, "", '1', func() { applyCategory(catWindows) }). + AddItem(cv.str.CatModalLinux, "", '2', func() { applyCategory(catLinux) }). + AddItem(cv.str.CatModalNetDevice, "", '3', func() { applyCategory(catNetworkDevice) }). AddItem(cv.str.BtnCancel, "", 'q', closeModal) list.SetBorder(true).SetTitle(fmt.Sprintf(cv.str.FmtCatModalTitle, ip)) @@ -984,13 +1006,7 @@ func (cv *CorrelationViewer) showScreenshotModal() { } return nil case 'o': - ss := screenshots[currentIdx] - cv.app.Suspend(func() { - cmd := exec.Command("xdg-open", ss.File) - if err := cmd.Run(); err != nil { - fmt.Fprintf(os.Stderr, "failed to open %s: %v\n", ss.File, err) - } - }) + cv.openScreenshotExternally(screenshots[currentIdx].File) return nil } } @@ -1003,6 +1019,15 @@ func (cv *CorrelationViewer) showScreenshotModal() { cv.app.ForceDraw() } +func (cv *CorrelationViewer) openScreenshotExternally(filePath string) { + cv.app.Suspend(func() { + cmd := exec.Command("xdg-open", filePath) + if err := cmd.Run(); err != nil { + fmt.Fprintf(os.Stderr, "failed to open %s: %v\n", filePath, err) + } + }) +} + // tcellCell holds a single character cell with foreground and background colors. type tcellCell struct { char rune @@ -1055,11 +1080,11 @@ func (cv *CorrelationViewer) loadScreenshot(path string) (image.Image, error) { return img, nil } - file, err := os.Open(path) + file, err := os.Open(path) //nolint:gosec // G304: path from trusted workspace if err != nil { return nil, fmt.Errorf("failed to open screenshot: %w", err) } - defer file.Close() + defer func() { _ = file.Close() }() img, _, err := image.Decode(file) if err != nil { diff --git a/internal/ui/dashboard.go b/internal/ui/dashboard.go index dae395f..77134b3 100644 --- a/internal/ui/dashboard.go +++ b/internal/ui/dashboard.go @@ -203,20 +203,20 @@ func (d *Dashboard) calculateStats(correlations map[string]*correlation.Correlat // updateStatsPanel renders the discovery statistics panel. func (d *Dashboard) updateStatsPanel(stats DashboardStats) { var content strings.Builder - content.WriteString(fmt.Sprintf(d.str.FmtDashStatsHostsDiscovered, stats.TotalHosts)) - content.WriteString(fmt.Sprintf(d.str.FmtDashStatsWindows, stats.HostsByCategory["windows"])) - content.WriteString(fmt.Sprintf(d.str.FmtDashStatsLinux, stats.HostsByCategory["linux"])) - content.WriteString(fmt.Sprintf(d.str.FmtDashStatsNetDevices, stats.HostsByCategory["network_device"])) - content.WriteString(fmt.Sprintf(d.str.FmtDashStatsUnknown, stats.HostsByCategory["unknown"])) + fmt.Fprintf(&content, d.str.FmtDashStatsHostsDiscovered, stats.TotalHosts) + fmt.Fprintf(&content, d.str.FmtDashStatsWindows, stats.HostsByCategory[catWindows]) + fmt.Fprintf(&content, d.str.FmtDashStatsLinux, stats.HostsByCategory[catLinux]) + fmt.Fprintf(&content, d.str.FmtDashStatsNetDevices, stats.HostsByCategory[catNetworkDevice]) + fmt.Fprintf(&content, d.str.FmtDashStatsUnknown, stats.HostsByCategory[catUnknown]) content.WriteString("\n") - content.WriteString(fmt.Sprintf(d.str.FmtDashStatsServices, stats.TotalServices)) + fmt.Fprintf(&content, d.str.FmtDashStatsServices, stats.TotalServices) content.WriteString("\n") content.WriteString(d.str.DashJobsHeading) - content.WriteString(fmt.Sprintf(d.str.FmtDashJobsRunning, stats.RunningJobs, stats.MaxConcurrent)) - content.WriteString(fmt.Sprintf(d.str.FmtDashJobsCompleted, stats.CompletedJobs)) - content.WriteString(fmt.Sprintf(d.str.FmtDashJobsFailed, stats.FailedJobs)) + fmt.Fprintf(&content, d.str.FmtDashJobsRunning, stats.RunningJobs, stats.MaxConcurrent) + fmt.Fprintf(&content, d.str.FmtDashJobsCompleted, stats.CompletedJobs) + fmt.Fprintf(&content, d.str.FmtDashJobsFailed, stats.FailedJobs) if !stats.LastScanTime.IsZero() { - content.WriteString(fmt.Sprintf(d.str.FmtDashLastScan, stats.LastScanTime.Format("15:04"))) + fmt.Fprintf(&content, d.str.FmtDashLastScan, stats.LastScanTime.Format("15:04")) } d.statsPanel.SetText(content.String()) } @@ -225,7 +225,7 @@ func (d *Dashboard) updateStatsPanel(stats DashboardStats) { func (d *Dashboard) updateChartsPanel(stats DashboardStats) { var content strings.Builder - categories := []string{"windows", "linux", "network_device", "unknown"} + categories := []string{catWindows, catLinux, catNetworkDevice, catUnknown} maxCount := 0 for _, cat := range categories { if n := stats.HostsByCategory[cat]; n > maxCount { @@ -246,7 +246,7 @@ func (d *Dashboard) updateChartsPanel(stats DashboardStats) { bar := strings.Repeat("█", filled) + strings.Repeat("░", barWidth-filled) color := categoryTviewColor(cat) label := d.str.CategoryDisplayLabel(cat) - content.WriteString(fmt.Sprintf(d.str.FmtDashCategoryBar, color, label, color, bar, count)) + fmt.Fprintf(&content, d.str.FmtDashCategoryBar, color, label, color, bar, count) } } @@ -280,7 +280,7 @@ func (d *Dashboard) updateActivityPanel() { var prefix, color string switch status { case jobs.JobStatusRunning: - prefix, color = "●", "green" + prefix, color = "●", colorGreen case jobs.JobStatusCompleted: prefix, color = "✓", "blue" case jobs.JobStatusFailed: @@ -303,8 +303,8 @@ func (d *Dashboard) updateActivityPanel() { name = string([]rune(name)[:21]) + "…" } - content.WriteString(fmt.Sprintf("[%s]%s %s %s[::-]\n", - color, prefix, name, formatJobDuration(dur))) + fmt.Fprintf(&content, "[%s]%s %s %s[::-]\n", + color, prefix, name, formatJobDuration(dur)) } if len(allJobs) == 0 { @@ -324,7 +324,7 @@ var riskTiers = []struct { {700, "Critical", "red", tcell.ColorRed}, {500, "High", "orange", tcell.ColorOrange}, {200, "Medium", "yellow", tcell.ColorYellow}, - {0, "Low", "green", tcell.ColorGreen}, + {0, "Low", colorGreen, tcell.ColorGreen}, } // updateRiskPanel renders aggregate risk posture across all correlated hosts. @@ -371,13 +371,13 @@ func (d *Dashboard) updateRiskPanel(correlations map[string]*correlation.Correla content.WriteString(d.str.DashRiskDistHeading) for _, tier := range riskTiers { count := tierCounts[tier.label] - content.WriteString(fmt.Sprintf(d.str.FmtDashRiskTierLine, tier.tviewColor, d.str.RiskLabel(tier.label), count)) + fmt.Fprintf(&content, d.str.FmtDashRiskTierLine, tier.tviewColor, d.str.RiskLabel(tier.label), count) } content.WriteString(d.str.DashSevSummaryHeading) for _, sev := range []string{"critical", "high", "medium", "low", "info"} { if count := severityCounts[sev]; count > 0 { - content.WriteString(fmt.Sprintf(" [%s]%-10s %d[::-]\n", severityTviewColor(sev), d.str.SeverityLabel(sev)+":", count)) + fmt.Fprintf(&content, " [%s]%-10s %d[::-]\n", severityTviewColor(sev), d.str.SeverityLabel(sev)+":", count) } } @@ -385,17 +385,17 @@ func (d *Dashboard) updateRiskPanel(correlations map[string]*correlation.Correla if niktoCount > 0 || sslCount > 0 { content.WriteString(d.str.DashBySourceHeading) if niktoCount > 0 { - content.WriteString(fmt.Sprintf(d.str.FmtDashNiktoFindings, niktoCount)) + fmt.Fprintf(&content, d.str.FmtDashNiktoFindings, niktoCount) } if sslCount > 0 { - content.WriteString(fmt.Sprintf(d.str.FmtDashSSLIssues, sslCount)) + fmt.Fprintf(&content, d.str.FmtDashSSLIssues, sslCount) } } avgScore := totalScore / len(correlations) - content.WriteString(fmt.Sprintf(d.str.FmtDashAvgScore, avgScore)) + fmt.Fprintf(&content, d.str.FmtDashAvgScore, avgScore) if highestIP != "" { - content.WriteString(fmt.Sprintf(d.str.FmtDashHighestRisk, highestIP, highestScore)) + fmt.Fprintf(&content, d.str.FmtDashHighestRisk, highestIP, highestScore) } d.riskPanel.SetText(content.String()) @@ -556,13 +556,13 @@ func (d *Dashboard) updateServicesPanel(correlations map[string]*correlation.Cor } for i := 0; i < maxShow; i++ { s := svcs[i] - content.WriteString(fmt.Sprintf(d.str.FmtDashServiceEntry, s.name, s.count)) + fmt.Fprintf(&content, d.str.FmtDashServiceEntry, s.name, s.count) } content.WriteString(d.str.DashPortsHeading) - content.WriteString(fmt.Sprintf(d.str.FmtDashUniqueOpenPorts, len(portSet))) + fmt.Fprintf(&content, d.str.FmtDashUniqueOpenPorts, len(portSet)) if maxPortHost != "" { - content.WriteString(fmt.Sprintf(d.str.FmtDashMostExposedHost, maxPortHost, maxPortCount)) + fmt.Fprintf(&content, d.str.FmtDashMostExposedHost, maxPortHost, maxPortCount) } d.servicesPanel.SetText(content.String()) @@ -630,42 +630,14 @@ func severityTviewColor(severity string) string { case "medium": return "yellow" case "low": - return "green" + return colorGreen default: - return "gray" + return colorGray } } -// showHostDetailsModal displays a scrollable vulnerability-focused detail view for the selected host. -func (d *Dashboard) showHostDetailsModal(hostIP string, corr *correlation.CorrelationResult) { - cat := hostCategory(corr) - hostname := hostHostname(corr) - osLabel := "-" - if corr.HostInfo != nil { - if corr.HostInfo.OSDetails != "" { - osLabel = corr.HostInfo.OSDetails - } else if corr.HostInfo.OS != "" { - osLabel = corr.HostInfo.OS - } - } - - // Header with OS context - var headerExtra string - if cat != "unknown" { - headerExtra = fmt.Sprintf(" — [%s]%s[::-]", categoryTviewColor(cat), cat) - } - if osLabel != "-" { - headerExtra += fmt.Sprintf(" [gray](%s)[::-]", osLabel) - } - var details strings.Builder - if hostname != "-" { - details.WriteString(fmt.Sprintf(d.str.FmtHostRiskDetailWithHost, hostIP, hostname, headerExtra)) - } else { - details.WriteString(fmt.Sprintf(d.str.FmtHostRiskDetail, hostIP, headerExtra)) - } - details.WriteString("\n\n") - - // Risk score with tier +// renderRiskScore writes the risk score tier and breakdown to details. +func (d *Dashboard) renderRiskScore(details *strings.Builder, corr *correlation.CorrelationResult) { var tierLabel, tierColor string for _, tier := range riskTiers { if corr.RiskScore >= tier.minScore { @@ -674,14 +646,14 @@ func (d *Dashboard) showHostDetailsModal(hostIP string, corr *correlation.Correl break } } - details.WriteString(fmt.Sprintf(d.str.FmtRiskScore, tierColor, corr.RiskScore, tierColor, d.str.RiskLabel(tierLabel))) + fmt.Fprintf(details, d.str.FmtRiskScore, tierColor, corr.RiskScore, tierColor, d.str.RiskLabel(tierLabel)) // Risk breakdown bd := corr.RiskDetails - details.WriteString(fmt.Sprintf(d.str.FmtRiskBreakdownVulns, bd.VulnerabilityScore)) - details.WriteString(fmt.Sprintf(d.str.FmtRiskBreakdownService, bd.ServiceExposure)) - details.WriteString(fmt.Sprintf(d.str.FmtRiskBreakdownSSL, bd.SSLIssues)) - details.WriteString(fmt.Sprintf(d.str.FmtRiskBreakdownPorts, bd.OpenPortScore)) + fmt.Fprintf(details, d.str.FmtRiskBreakdownVulns, bd.VulnerabilityScore) + fmt.Fprintf(details, d.str.FmtRiskBreakdownService, bd.ServiceExposure) + fmt.Fprintf(details, d.str.FmtRiskBreakdownSSL, bd.SSLIssues) + fmt.Fprintf(details, d.str.FmtRiskBreakdownPorts, bd.OpenPortScore) // Risk factors by category factorCategories := []struct { @@ -699,7 +671,7 @@ func (d *Dashboard) showHostDetailsModal(hostIP string, corr *correlation.Correl if len(catFactors) == 0 { continue } - details.WriteString(fmt.Sprintf(d.str.FmtRiskFactorCategory, cat.label, len(catFactors))) + fmt.Fprintf(details, d.str.FmtRiskFactorCategory, cat.label, len(catFactors)) maxFactors := len(catFactors) if maxFactors > 15 { maxFactors = 15 @@ -710,18 +682,20 @@ func (d *Dashboard) showHostDetailsModal(hostIP string, corr *correlation.Correl if f.Source != "" { source = fmt.Sprintf(" (%s)", f.Source) } - details.WriteString(fmt.Sprintf(d.str.FmtRiskFactorLine, f.Title, f.Score, source)) + fmt.Fprintf(details, d.str.FmtRiskFactorLine, f.Title, f.Score, source) } if len(catFactors) > maxFactors { - details.WriteString(fmt.Sprintf(d.str.FmtAndMore, len(catFactors)-maxFactors)) + fmt.Fprintf(details, d.str.FmtAndMore, len(catFactors)-maxFactors) } } +} - // Vulnerabilities by severity +// renderVulnerabilities writes vulnerabilities grouped by severity to details. +func (d *Dashboard) renderVulnerabilities(details *strings.Builder, corr *correlation.CorrelationResult) { severities := []struct { sev string color string - }{{"critical", "red"}, {"high", "orange"}, {"medium", "yellow"}, {"low", "green"}, {"info", "gray"}} + }{{"critical", "red"}, {"high", "orange"}, {"medium", colorYellow}, {"low", colorGreen}, {"info", colorGray}} for _, s := range severities { var sevVulns []correlation.Vulnerability @@ -733,7 +707,7 @@ func (d *Dashboard) showHostDetailsModal(hostIP string, corr *correlation.Correl if len(sevVulns) == 0 { continue } - details.WriteString(fmt.Sprintf(d.str.FmtSevFindings, s.color, d.str.SeverityLabel(s.sev))) + fmt.Fprintf(details, d.str.FmtSevFindings, s.color, d.str.SeverityLabel(s.sev)) maxShow := len(sevVulns) if maxShow > 15 { maxShow = 15 @@ -748,15 +722,17 @@ func (d *Dashboard) showHostDetailsModal(hostIP string, corr *correlation.Correl } line += ")" } - details.WriteString(fmt.Sprintf("%s\n", line)) + fmt.Fprintf(details, "%s\n", line) } if len(sevVulns) > maxShow { - details.WriteString(fmt.Sprintf(d.str.FmtAndMore, len(sevVulns)-maxShow)) + fmt.Fprintf(details, d.str.FmtAndMore, len(sevVulns)-maxShow) } details.WriteString("\n") } +} - // Open ports summary +// renderOpenPorts writes the open ports summary to details. +func (d *Dashboard) renderOpenPorts(details *strings.Builder, corr *correlation.CorrelationResult) { var openPorts []string if corr.HostInfo != nil { for _, p := range corr.HostInfo.Ports { @@ -766,8 +742,46 @@ func (d *Dashboard) showHostDetailsModal(hostIP string, corr *correlation.Correl } } if len(openPorts) > 0 { - details.WriteString(fmt.Sprintf(d.str.FmtHostOpenPorts, strings.Join(openPorts, ", "))) + fmt.Fprintf(details, d.str.FmtHostOpenPorts, strings.Join(openPorts, ", ")) } +} + +// showHostDetailsModal displays a scrollable vulnerability-focused detail view for the selected host. +func (d *Dashboard) showHostDetailsModal(hostIP string, corr *correlation.CorrelationResult) { + cat := hostCategory(corr) + hostname := hostHostname(corr) + osLabel := "-" + if corr.HostInfo != nil { + if corr.HostInfo.OSDetails != "" { + osLabel = corr.HostInfo.OSDetails + } else if corr.HostInfo.OS != "" { + osLabel = corr.HostInfo.OS + } + } + + // Header with OS context + var headerExtra string + if cat != catUnknown { + headerExtra = fmt.Sprintf(" — [%s]%s[::-]", categoryTviewColor(cat), cat) + } + if osLabel != "-" { + headerExtra += fmt.Sprintf(" [gray](%s)[::-]", osLabel) + } + var details strings.Builder + if hostname != "-" { + fmt.Fprintf(&details, d.str.FmtHostRiskDetailWithHost, hostIP, hostname, headerExtra) + } else { + fmt.Fprintf(&details, d.str.FmtHostRiskDetail, hostIP, headerExtra) + } + details.WriteString("\n\n") + + // Risk score with tier + d.renderRiskScore(&details, corr) + // Vulnerabilities by severity + d.renderVulnerabilities(&details, corr) + + // Open ports summary + d.renderOpenPorts(&details, corr) // Build scrollable view textView := tview.NewTextView(). diff --git a/internal/ui/output.go b/internal/ui/output.go index b8ae1c5..6464e62 100644 --- a/internal/ui/output.go +++ b/internal/ui/output.go @@ -2,6 +2,7 @@ package ui import ( "fmt" + "os" "strings" "sync" "time" @@ -365,7 +366,7 @@ func (ov *OutputViewer) pollJobOutput(job *jobs.Job, startIdx int) { ov.running = false ov.completed = true status := "Completed" - statusColor := "green" + statusColor := colorGreen duration := time.Duration(0) if job.Result != nil { duration = job.Result.Duration @@ -419,7 +420,7 @@ func (ov *OutputViewer) processOutput() { if ov.result != nil { status := "Completed" - statusColor := "green" + statusColor := colorGreen if !ov.result.Success { status = "Failed" statusColor = "red" @@ -557,11 +558,6 @@ func (ov *OutputViewer) handlePromptDetection(line executor.OutputLine) { } } -func (ov *OutputViewer) updateDisplay() { - ov.mu.RLock() - defer ov.mu.RUnlock() - ov.updateDisplayLocked() -} func (ov *OutputViewer) updateDisplayLocked() { lines := ov.outputLines @@ -583,8 +579,8 @@ func (ov *OutputViewer) formatLinesLocked(lines []executor.OutputLine) string { for _, line := range lines { if ov.showTimestamp { - content.WriteString(fmt.Sprintf("[gray]%s[white] ", - line.Timestamp.Format("15:04:05"))) + fmt.Fprintf(&content, "[gray]%s[white] ", + line.Timestamp.Format("15:04:05")) } if ov.showSource { @@ -595,9 +591,9 @@ func (ov *OutputViewer) formatLinesLocked(lines []executor.OutputLine) string { case "error": color = "red" case "stdout": - color = "green" + color = colorGreen } - content.WriteString(fmt.Sprintf("[%s]%s[white] ", color, line.Source)) + fmt.Fprintf(&content, "[%s]%s[white] ", color, line.Source) } lineContent := tview.TranslateANSI(line.Content) @@ -619,6 +615,7 @@ func (ov *OutputViewer) formatLinesLocked(lines []executor.OutputLine) string { return content.String() } + func (ov *OutputViewer) ShowHistoricalOutput(jobName string, status jobs.JobStatus, lines []executor.OutputLine) { ov.mu.Lock() ov.outputLines = append(ov.outputLines, lines...) @@ -629,10 +626,11 @@ func (ov *OutputViewer) ShowHistoricalOutput(jobName string, status jobs.JobStat ov.updateDisplayLocked() ov.mu.Unlock() - statusColor := "green" - if status == jobs.JobStatusFailed { + statusColor := colorGreen + switch status { + case jobs.JobStatusFailed: statusColor = "red" - } else if status == jobs.JobStatusCancelled { + case jobs.JobStatusCancelled: statusColor = "gray" } @@ -648,7 +646,9 @@ func (ov *OutputViewer) cancelJob() { ov.mu.RUnlock() if jobID != "" && ov.jobManager != nil { - ov.jobManager.CancelJob(jobID) + if err := ov.jobManager.CancelJob(jobID); err != nil { + fmt.Fprintf(os.Stderr, "Warning: failed to cancel job %s: %v\n", jobID, err) + } } } diff --git a/internal/ui/tui.go b/internal/ui/tui.go index 8ae0104..87f83b8 100644 --- a/internal/ui/tui.go +++ b/internal/ui/tui.go @@ -297,7 +297,9 @@ func mergeInterfaceTasks(categories []Category, str *Strings) []Category { ifaceIdx = len(newTasks) } newTasks = append(newTasks[:ifaceIdx:ifaceIdx], append([]Task{composite}, newTasks[ifaceIdx:]...)...) - categories[ci].Tasks = newTasks + if ci < len(categories) { + categories[ci].Tasks = newTasks + } break } return categories @@ -311,65 +313,88 @@ func mergeCaptureAnalysisTasks(categories []Category, str *Strings) []Category { if cat.Name != str.CatNetworkDiscovery { continue } - var vlanTask, macTask, captureTask Task - firstIdx := -1 - for ti, task := range cat.Tasks { - switch task.CanonicalName { - case "Extract VLANs": - vlanTask = task - if firstIdx == -1 { - firstIdx = ti - } - case "MAC Address Analysis": - macTask = task - if firstIdx == -1 { - firstIdx = ti - } - case "Packet Capture Analysis": - captureTask = task - if firstIdx == -1 { - firstIdx = ti - } - } + if newTasks := buildMergedCaptureTasks(cat.Tasks, str); newTasks != nil { + categories[ci].Tasks = newTasks } - if firstIdx == -1 || vlanTask.CanonicalName == "" || macTask.CanonicalName == "" || captureTask.CanonicalName == "" { - continue - } - newTasks := make([]Task, 0, len(cat.Tasks)-2) - for _, t := range cat.Tasks { - if t.CanonicalName != "Extract VLANs" && t.CanonicalName != "MAC Address Analysis" && t.CanonicalName != "Packet Capture Analysis" { - newTasks = append(newTasks, t) + break + } + return categories +} + +// mergeCaptureNames lists the canonical task names that get merged into the composite. +var mergeCaptureNames = []string{"Extract VLANs", "MAC Address Analysis", "Packet Capture Analysis"} + +// buildMergedCaptureTasks merges the VLAN/MAC/PacketCapture tasks into a single +// composite "Network Capture Analysis" task. Returns nil if the required tasks +// are not all present. +func buildMergedCaptureTasks(tasks []Task, str *Strings) []Task { + var vlanTask, macTask, captureTask Task + firstIdx := -1 + for ti, task := range tasks { + switch task.CanonicalName { + case "Extract VLANs": + vlanTask = task + if firstIdx == -1 { + firstIdx = ti + } + case "MAC Address Analysis": + macTask = task + if firstIdx == -1 { + firstIdx = ti + } + case "Packet Capture Analysis": + captureTask = task + if firstIdx == -1 { + firstIdx = ti } } - composite := Task{ - Name: str.TaskNetworkCaptureAnalysis, - CanonicalName: "Network Capture Analysis", - Description: str.TaskNetworkCaptureAnalysisDesc, - SubTasks: []Task{ - {Name: str.TaskExtractVLANIDs, CanonicalName: "Extract VLANs", Description: vlanTask.Description, Script: vlanTask.Script}, - {Name: macTask.Name, CanonicalName: "MAC Address Analysis", Description: macTask.Description, Script: macTask.Script}, - {Name: captureTask.Name, CanonicalName: "Packet Capture Analysis", Description: captureTask.Description, Script: captureTask.Script}, - }, - } - // Insert composite immediately after "Network Capture" if present, otherwise at firstIdx - insertIdx := -1 - for ti, t := range newTasks { - if t.CanonicalName == "Network Capture" { - insertIdx = ti + 1 + } + if firstIdx == -1 || vlanTask.CanonicalName == "" || macTask.CanonicalName == "" || captureTask.CanonicalName == "" { + return nil + } + + newTasks := make([]Task, 0, len(tasks)-2) + for _, t := range tasks { + isMergeTarget := false + for _, name := range mergeCaptureNames { + if t.CanonicalName == name { + isMergeTarget = true break } } - if insertIdx == -1 { - insertIdx = firstIdx - if insertIdx > len(newTasks) { - insertIdx = len(newTasks) - } + if !isMergeTarget { + newTasks = append(newTasks, t) } - newTasks = append(newTasks[:insertIdx:insertIdx], append([]Task{composite}, newTasks[insertIdx:]...)...) - categories[ci].Tasks = newTasks - break } - return categories + + composite := Task{ + Name: str.TaskNetworkCaptureAnalysis, + CanonicalName: "Network Capture Analysis", + Description: str.TaskNetworkCaptureAnalysisDesc, + SubTasks: []Task{ + {Name: str.TaskExtractVLANIDs, CanonicalName: "Extract VLANs", Description: vlanTask.Description, Script: vlanTask.Script}, + {Name: macTask.Name, CanonicalName: "MAC Address Analysis", Description: macTask.Description, Script: macTask.Script}, + {Name: captureTask.Name, CanonicalName: "Packet Capture Analysis", Description: captureTask.Description, Script: captureTask.Script}, + }, + } + + insertIdx := findInsertIndex(newTasks, firstIdx, "Network Capture") + newTasks = append(newTasks[:insertIdx:insertIdx], append([]Task{composite}, newTasks[insertIdx:]...)...) + return newTasks +} + +// findInsertIndex returns the position after the task with afterCanonicalName, +// or falls back to defaultIdx clamped to the slice length. +func findInsertIndex(tasks []Task, defaultIdx int, afterCanonicalName string) int { + for ti, t := range tasks { + if t.CanonicalName == afterCanonicalName { + return ti + 1 + } + } + if defaultIdx > len(tasks) { + return len(tasks) + } + return defaultIdx } // ensureTrueColor sets COLORTERM=truecolor if not already set, enabling @@ -538,7 +563,7 @@ func (t *TUI) updateJobsPanel() { // Duration for running jobs if status == jobs.JobStatusRunning { dur := time.Since(job.StartTime).Round(time.Second) - sb.WriteString(fmt.Sprintf(" %v", dur)) + fmt.Fprintf(&sb, " %v", dur) } sb.WriteString("\n") @@ -553,7 +578,7 @@ func (t *TUI) updateJobsPanel() { sb.WriteString(renderProgressBar(current, total, desc)) } else { idx := int(time.Now().Unix()) % len(indicatorChars) - sb.WriteString(fmt.Sprintf(" %s %s", indicatorChars[idx], t.str.ProgressRunning)) + fmt.Fprintf(&sb, " %s %s", indicatorChars[idx], t.str.ProgressRunning) } } sb.WriteString("\n") @@ -618,7 +643,7 @@ func (t *TUI) checkSysConfigDone() bool { if name == "lo" { continue } - state, err := os.ReadFile(filepath.Join("/sys/class/net", name, "operstate")) + state, err := os.ReadFile(filepath.Join("/sys/class/net", name, "operstate")) //nolint:gosec // G304: kernel path if err != nil || strings.TrimSpace(string(state)) != "up" { continue } @@ -847,39 +872,49 @@ func (t *TUI) setupUI() { t.setActiveFocus(t.categoryPane) } -func (t *TUI) handleGlobalKeys(event *tcell.EventKey) *tcell.EventKey { - // Consume Ctrl+C to prevent tview's built-in handler from stopping the application. - // tview calls a.Stop() on unhandled Ctrl+C (application.go ~L433). All pages handle - // their own cancellation — the output viewer uses it to stop streaming, the main page - // has no use for it. Either way the TUI must stay alive. - if event.Key() == tcell.KeyCtrlC { - if t.outputViewer != nil { - t.outputViewer.CancelAndReturn() - } - return nil - } - // Handle global Ctrl+key shortcuts that work everywhere (including output viewer) +// handleGlobalCtrlShortcuts processes global Ctrl+key shortcuts that work on every page +// (including the output viewer). Returns true if the event was consumed. +func (t *TUI) handleGlobalCtrlShortcuts(event *tcell.EventKey) bool { if event.Key() == tcell.KeyCtrlJ { // Global Job Manager access - works even during script execution t.showJobsManager() - return nil + return true } if event.Key() == tcell.KeyCtrlD { // Global Dashboard access t.showDashboard() - return nil + return true } if event.Key() == tcell.KeyCtrlN { // Global Host view access t.showCorrelationViewer() - return nil + return true } // Ctrl+Z: tcell puts the terminal in raw mode (ISIG cleared), // so the kernel never converts Ctrl+Z to SIGTSTP — safe to use as a keybind. if event.Key() == tcell.KeyCtrlZ { // Global return to main TUI from anywhere t.returnToMain() + return true + } + return false +} + +func (t *TUI) handleGlobalKeys(event *tcell.EventKey) *tcell.EventKey { + // Consume Ctrl+C to prevent tview's built-in handler from stopping the application. + // tview calls a.Stop() on unhandled Ctrl+C (application.go ~L433). All pages handle + // their own cancellation — the output viewer uses it to stop streaming, the main page + // has no use for it. Either way the TUI must stay alive. + if event.Key() == tcell.KeyCtrlC { + if t.outputViewer != nil { + t.outputViewer.CancelAndReturn() + } + return nil + } + + // Handle global Ctrl+key shortcuts that work everywhere (including output viewer) + if t.handleGlobalCtrlShortcuts(event) { return nil } @@ -1075,7 +1110,9 @@ func (t *TUI) executeTaskWithStreaming(scriptPath, taskName string) { job := t.jobManager.CreateJob(jobID, taskName, absPath) if err := t.jobManager.StartJob(job.ID); err != nil { // Unexpected failure — clean up the orphan and show options - t.jobManager.RemoveJob(job.ID) + if err := t.jobManager.RemoveJob(job.ID); err != nil { + fmt.Fprintf(os.Stderr, "Warning: failed to remove failed job %s: %v\n", job.ID, err) + } t.showExecutionOptions(absPath, taskName) return } @@ -1231,180 +1268,185 @@ func (t *TUI) Stop() { }) } -// startSearch opens a compact centered modal for searching tasks across all categories. -func (t *TUI) startSearch() { - prevFocus := t.categoryPane - if t.app.GetFocus() == t.taskPane { - prevFocus = t.taskPane - } +// searchState holds mutable state for an active search modal session. +type searchState struct { + tui *TUI + results []SearchResult + continuations []bool + resultIdx []int + resultList *tview.List + inputField *tview.InputField + prevFocus *tview.List +} - var results []SearchResult +// isContinuation reports whether list item at idx is a wrapped description line. +func (s *searchState) isContinuation(idx int) bool { + return idx >= 0 && idx < len(s.continuations) && s.continuations[idx] +} - closeModal := func() { - t.pages.RemovePage("search") - t.setActiveFocus(prevFocus) +// resultIndex maps a list item index to its results[] index, or -1. +func (s *searchState) resultIndex(idx int) int { + if idx < 0 || idx >= len(s.resultIdx) { + return -1 } + return s.resultIdx[idx] +} - inputField := tview.NewInputField(). - SetLabel(t.str.SearchLabel). - SetFieldWidth(0) - - resultList := tview.NewList().ShowSecondaryText(false) - - // searchContinuations[i] == true means list item i is a wrapped description line, not a result - var searchContinuations []bool - // searchResultIdx[i] is the results[] index for list item i (-1 for continuation items) - var searchResultIdx []int - - isCont := func(idx int) bool { - return idx >= 0 && idx < len(searchContinuations) && searchContinuations[idx] +// moveDown advances past continuation lines. +func (s *searchState) moveDown() { + cur := s.resultList.GetCurrentItem() + count := s.resultList.GetItemCount() + next := cur + 1 + for next < count && s.isContinuation(next) { + next++ } - resultForIdx := func(idx int) int { - if idx < 0 || idx >= len(searchResultIdx) { - return -1 - } - return searchResultIdx[idx] + if next < count { + s.resultList.SetCurrentItem(next) } +} - moveSearchDown := func() { - cur := resultList.GetCurrentItem() - count := resultList.GetItemCount() - next := cur + 1 - for next < count && isCont(next) { - next++ - } - if next < count { - resultList.SetCurrentItem(next) - } +// moveUp retreats past continuation lines, returning false if at top. +func (s *searchState) moveUp() bool { + cur := s.resultList.GetCurrentItem() + prev := cur - 1 + for prev >= 0 && s.isContinuation(prev) { + prev-- } - moveSearchUp := func() { - cur := resultList.GetCurrentItem() - prev := cur - 1 - for prev >= 0 && isCont(prev) { - prev-- - } - if prev >= 0 { - resultList.SetCurrentItem(prev) - } + if prev < 0 { + s.tui.app.SetFocus(s.inputField) + return false } + s.resultList.SetCurrentItem(prev) + return true +} - updateResults := func(query string) { - resultList.Clear() - searchContinuations = searchContinuations[:0] - searchResultIdx = searchResultIdx[:0] - results = t.searchAllCategories(query) - - _, _, listWidth, _ := resultList.GetInnerRect() - if listWidth <= 0 { - listWidth = 36 // fallback before first draw (40% of 80col - borders) +// populateSearchResults clears and repopulates the result list for query. +func (s *searchState) populateSearchResults(query string) { + s.resultList.Clear() + s.continuations = s.continuations[:0] + s.resultIdx = s.resultIdx[:0] + s.results = s.tui.searchAllCategories(query) + + _, _, listWidth, _ := s.resultList.GetInnerRect() + if listWidth <= 0 { + listWidth = 36 + } + + for i, r := range s.results { + header := fmt.Sprintf("%s [%s]", r.Task.Name, r.CategoryName) + s.resultList.AddItem(header, "", 0, nil) + s.continuations = append(s.continuations, false) + s.resultIdx = append(s.resultIdx, i) + for _, line := range wrapText(r.Task.Description, listWidth) { + s.resultList.AddItem("[green] "+line+"[white]", "", 0, nil) + s.continuations = append(s.continuations, true) + s.resultIdx = append(s.resultIdx, i) } + } +} - for i, r := range results { - header := fmt.Sprintf("%s [%s]", r.Task.Name, r.CategoryName) - resultList.AddItem(header, "", 0, nil) - searchContinuations = append(searchContinuations, false) - searchResultIdx = append(searchResultIdx, i) - for _, line := range wrapText(r.Task.Description, listWidth) { - resultList.AddItem("[green] "+line+"[white]", "", 0, nil) - searchContinuations = append(searchContinuations, true) - searchResultIdx = append(searchResultIdx, i) - } +// inputCapture handles key events on the search input field. +func (s *searchState) inputCapture(event *tcell.EventKey) *tcell.EventKey { + switch event.Key() { + case tcell.KeyDown, tcell.KeyEnter: + if s.resultList.GetItemCount() > 0 { + s.tui.app.SetFocus(s.resultList) + s.resultList.SetCurrentItem(0) } + return nil + case tcell.KeyEscape: + s.close() + return nil } + return event +} - inputField.SetChangedFunc(updateResults) - - inputField.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { - switch event.Key() { - case tcell.KeyDown, tcell.KeyEnter: - if resultList.GetItemCount() > 0 { - t.app.SetFocus(resultList) - resultList.SetCurrentItem(0) - } +// listCapture handles key events on the search result list. +func (s *searchState) listCapture(event *tcell.EventKey) *tcell.EventKey { + switch event.Key() { + case tcell.KeyEscape: + s.close() + return nil + case tcell.KeyUp: + s.moveUp() + return nil + case tcell.KeyDown: + s.moveDown() + return nil + case tcell.KeyRune: + switch event.Rune() { + case 'j': + s.moveDown() return nil - case tcell.KeyEscape: - closeModal() + case 'k': + s.moveUp() return nil } - return event - }) + } + return event +} - resultList.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { - switch event.Key() { - case tcell.KeyEscape: - closeModal() - return nil - case tcell.KeyUp: - cur := resultList.GetCurrentItem() - prev := cur - 1 - for prev >= 0 && isCont(prev) { - prev-- - } - if prev < 0 { - t.app.SetFocus(inputField) - return nil - } - moveSearchUp() - return nil - case tcell.KeyDown: - moveSearchDown() - return nil - case tcell.KeyRune: - switch event.Rune() { - case 'j': - moveSearchDown() - return nil - case 'k': - cur := resultList.GetCurrentItem() - prev := cur - 1 - for prev >= 0 && isCont(prev) { - prev-- - } - if prev < 0 { - t.app.SetFocus(inputField) - return nil - } - moveSearchUp() - return nil - } - } - return event - }) +// selected handles activation of a search result. +func (s *searchState) selected(index int, _, _ string, _ rune) { + ri := s.resultIndex(index) + if ri < 0 || ri >= len(s.results) { + return + } + r := s.results[ri] + s.close() + s.tui.currentCategory = r.CategoryName + if len(r.Task.SubTasks) > 0 { + s.tui.showSubTaskMenu(r.Task) + } else { + s.tui.executeTaskWithStreaming(r.Task.Script, r.Task.Name) + } +} - resultList.SetSelectedFunc(func(index int, _, _ string, _ rune) { - ri := resultForIdx(index) - if ri < 0 || ri >= len(results) { - return - } - r := results[ri] - closeModal() - t.currentCategory = r.CategoryName - if len(r.Task.SubTasks) > 0 { - t.showSubTaskMenu(r.Task) - } else { - t.executeTaskWithStreaming(r.Task.Script, r.Task.Name) - } - }) +// close dismisses the search modal. +func (s *searchState) close() { + s.tui.pages.RemovePage("search") + s.tui.setActiveFocus(s.prevFocus) +} + +// buildSearchModal constructs the centered search modal layout. +func (s *searchState) buildSearchModal() tview.Primitive { + s.inputField.SetChangedFunc(s.populateSearchResults) + s.inputField.SetInputCapture(s.inputCapture) + s.resultList.SetInputCapture(s.listCapture) + s.resultList.SetSelectedFunc(s.selected) - // Content box: input on top, results list below contentBox := tview.NewFlex().SetDirection(tview.FlexRow). - AddItem(inputField, 3, 0, true). - AddItem(resultList, 0, 1, false) - contentBox.SetBorder(true).SetTitle(t.str.PaneTitleSearch) + AddItem(s.inputField, 3, 0, true). + AddItem(s.resultList, 0, 1, false) + contentBox.SetBorder(true).SetTitle(s.tui.str.PaneTitleSearch) - // Center: 40% wide (3:4:3), 60% tall (1:3:1) centerRow := tview.NewFlex().SetDirection(tview.FlexColumn). AddItem(nil, 0, 3, false). AddItem(contentBox, 0, 4, true). AddItem(nil, 0, 3, false) - modal := tview.NewFlex().SetDirection(tview.FlexRow). + return tview.NewFlex().SetDirection(tview.FlexRow). AddItem(nil, 0, 1, false). AddItem(centerRow, 0, 3, true). AddItem(nil, 0, 1, false) +} - t.pages.AddPage("search", modal, true, true) - t.app.SetFocus(inputField) +// startSearch opens a compact centered modal for searching tasks across all categories. +func (t *TUI) startSearch() { + prevFocus := t.categoryPane + if t.app.GetFocus() == t.taskPane { + prevFocus = t.taskPane + } + + s := &searchState{ + tui: t, + prevFocus: prevFocus, + inputField: tview.NewInputField().SetLabel(t.str.SearchLabel).SetFieldWidth(0), + resultList: tview.NewList().ShowSecondaryText(false), + } + + t.pages.AddPage("search", s.buildSearchModal(), true, true) + t.app.SetFocus(s.inputField) } // showHelp displays help information @@ -1424,17 +1466,18 @@ func (t *TUI) updateInfoPanel() { current := t.app.GetFocus() var content strings.Builder - if current == t.categoryPane { + switch current { + case t.categoryPane: content.WriteString(t.str.InfoCatLine1) content.WriteString(t.str.InfoCatLine2) - } else if current == t.taskPane { + case t.taskPane: if t.currentCategory != "" { - content.WriteString(fmt.Sprintf(t.str.FmtInfoTaskLine1, t.currentCategory)) + fmt.Fprintf(&content, t.str.FmtInfoTaskLine1, t.currentCategory) } else { content.WriteString(t.str.InfoTaskNoCatLine1) } content.WriteString(t.str.InfoGlobalLine) - } else { + default: content.WriteString(t.str.InfoDefaultLine1) content.WriteString(t.str.InfoDefaultLine2) } diff --git a/internal/workflow/workflow.go b/internal/workflow/workflow.go index 0e59974..b32e4bb 100644 --- a/internal/workflow/workflow.go +++ b/internal/workflow/workflow.go @@ -2,6 +2,7 @@ package workflow import ( "fmt" + "os" "strings" "sync" "time" @@ -471,7 +472,9 @@ func (we *WorkflowEngine) executeScriptStep(workflow *Workflow, step *WorkflowSt step.mu.Unlock() return false, j.GetError() case <-timer.C: - we.jobManager.CancelJob(job.ID) + if err := we.jobManager.CancelJob(job.ID); err != nil { + fmt.Fprintf(os.Stderr, "Warning: failed to cancel timed-out job %s: %v\n", job.ID, err) + } return false, fmt.Errorf("step timed out after %v", timeout) } }