diff --git a/agent/backend/devicediscovery/device_discovery.go b/agent/backend/devicediscovery/device_discovery.go index ebf4152..19d4400 100644 --- a/agent/backend/devicediscovery/device_discovery.go +++ b/agent/backend/devicediscovery/device_discovery.go @@ -78,7 +78,9 @@ func (d *deviceDiscoveryBackend) Configure(logger *slog.Logger, repo policies.Po if d.apiHost, prs = config["host"].(string); !prs { d.apiHost = defaultAPIHost } - if d.apiPort, prs = config["port"].(string); !prs { + if port, prs := config["port"]; prs { + d.apiPort = fmt.Sprintf("%v", port) + } else { d.apiPort = defaultAPIPort } @@ -218,6 +220,13 @@ func (d *deviceDiscoveryBackend) Start(ctx context.Context, cancelFunc context.C var version string var readinessErr error for backoff := range readinessBackoff { + if status := d.proc.Status(); status.Complete { + err := d.proc.Stop() + if err != nil { + d.logger.Error("proc.Stop error", slog.Any("error", err)) + } + return errors.New("device-discovery process ended unexpectedly, check log") + } version, readinessErr = d.Version() if readinessErr == nil { d.logger.Info("device-discovery readiness ok, got version ", diff --git a/agent/backend/devicediscovery/device_discovery_test.go b/agent/backend/devicediscovery/device_discovery_test.go index 9c1f066..68068c7 100644 --- a/agent/backend/devicediscovery/device_discovery_test.go +++ b/agent/backend/devicediscovery/device_discovery_test.go @@ -8,10 +8,12 @@ import ( "net/http/httptest" "net/url" "os" + "path" "strings" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/netboxlabs/orb-agent/agent/backend" "github.com/netboxlabs/orb-agent/agent/backend/devicediscovery" @@ -21,265 +23,195 @@ import ( ) type StatusResponse struct { - StartTime string `json:"start_time"` Version string `json:"version"` + StartTime string `json:"start_time"` UpTime float64 `json:"up_time"` } func TestDeviceDiscoveryBackendStart(t *testing.T) { - // Create server server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - if r.URL.Path == "/api/v1/status" { + switch { + case r.URL.Path == "/api/v1/status": response := StatusResponse{ Version: "1.3.5", StartTime: "2023-10-01T12:00:00Z", UpTime: 123.456, } w.WriteHeader(http.StatusOK) - err := json.NewEncoder(w).Encode(response) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - } else if r.URL.Path == "/api/v1/capabilities" { + require.NoError(t, json.NewEncoder(w).Encode(response)) + case r.URL.Path == "/api/v1/capabilities": capabilities := map[string]any{ "capability": true, } w.WriteHeader(http.StatusOK) - err := json.NewEncoder(w).Encode(capabilities) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - } else if strings.Contains(r.URL.Path, "/api/v1/policies") { - switch r.Method { - case http.MethodPost: - w.WriteHeader(http.StatusOK) - response := map[string]any{ - "status": "success", - "message": "Policy applied successfully", - } - err := json.NewEncoder(w).Encode(response) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - case http.MethodDelete: - w.WriteHeader(http.StatusOK) - response := map[string]any{ - "status": "success", - "message": "Policy removed successfully", - } - err := json.NewEncoder(w).Encode(response) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - } - } else { + require.NoError(t, json.NewEncoder(w).Encode(capabilities)) + case strings.Contains(r.URL.Path, "/api/v1/policies"): + w.WriteHeader(http.StatusOK) + require.NoError(t, json.NewEncoder(w).Encode(map[string]any{ + "status": "success", + "message": "Policy operation successful", + })) + default: w.WriteHeader(http.StatusNotFound) } })) defer server.Close() - // Parse server URL serverURL, err := url.Parse(server.URL) - assert.NoError(t, err) + require.NoError(t, err) + + createExecutable(t, "device-discovery") - // Create logger logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - // Create a mock repository repo, err := policies.NewMemRepo() - assert.NoError(t, err) + require.NoError(t, err) - // Create a mock command mockCmd := &mocks.MockCmd{} mocks.SetupSuccessfulProcess(mockCmd, 12345) - // Save original function and restore after test - originalNewCmdOptions := backend.NewCmdOptions - defer func() { - backend.NewCmdOptions = originalNewCmdOptions - }() - - // Override NewCmdOptions to return our mock - backend.NewCmdOptions = func(options backend.CmdOptions, name string, args ...string) backend.Commander { - // Assert that the correct parameters were passed + overrideNewCmdOptions(t, mockCmd, func(options backend.CmdOptions, name string, args []string) { assert.Equal(t, "device-discovery", name, "Expected command name to be device-discovery") - assert.Contains(t, args, "--port", "Expected args to contain port") - assert.Contains(t, args, "--host", "Expected args to contain host") assert.False(t, options.Buffered, "Expected buffered to be false") assert.True(t, options.Streaming, "Expected streaming to be true") - return mockCmd - } + assert.Contains(t, args, "--host", "Expected args to contain host flag") + assert.Contains(t, args, serverURL.Hostname(), "Expected args to contain host value") + assert.Contains(t, args, "--port", "Expected args to contain port flag") + assert.Contains(t, args, serverURL.Port(), "Expected args to contain port value") + assert.Contains(t, args, "--diode-target", "Expected args to contain diode target flag") + assert.Contains(t, args, "device-target", "Expected args to contain diode target") + assert.Contains(t, args, "--diode-client-id", "Expected args to contain diode client id flag") + assert.Contains(t, args, "device-client", "Expected args to contain diode client id") + assert.Contains(t, args, "--diode-client-secret", "Expected args to contain diode client secret flag") + assert.Contains(t, args, "device-secret", "Expected args to contain diode client secret") + assert.Contains(t, args, "--diode-app-name-prefix", "Expected args to contain diode app name prefix flag") + assert.Contains(t, args, "device-agent", "Expected args to contain diode app name prefix") + assert.Contains(t, args, "--otel-endpoint", "Expected args to contain otel endpoint flag") + assert.Contains(t, args, "collector:4317", "Expected args to contain otel endpoint value") + }) assert.True(t, devicediscovery.Register(), "Failed to register DeviceDiscovery backend") - assert.True(t, backend.HaveBackend("device_discovery"), "Failed to get DeviceDiscovery backend") be := backend.GetBackend("device_discovery") assert.Equal(t, backend.Unknown, be.GetInitialState()) - // Configure backend + commons := config.BackendCommons{} + commons.Otlp.Grpc = "collector:4317" + commons.Diode.Target = "default-target" + commons.Diode.ClientID = "default-client" + commons.Diode.ClientSecret = "default-secret" + commons.Diode.AgentName = "default-agent" + commons.Diode.DryRunOutputDir = "/tmp/default" + err = be.Configure(logger, repo, map[string]any{ - "host": serverURL.Hostname(), - "port": serverURL.Port(), - }, config.BackendCommons{}) - assert.NoError(t, err) + "host": serverURL.Hostname(), + "port": serverURL.Port(), + "target": "device-target", + "client_id": "device-client", + "client_secret": "device-secret", + "agent_name": "device-agent", + "dry_run": false, + "dry_run_output_dir": "/tmp/device", + }, commons) + require.NoError(t, err) - // Start the backend ctx, cancel := context.WithCancel(context.Background()) - err = be.Start(ctx, cancel) + defer cancel() + + require.NoError(t, be.Start(ctx, cancel)) - // Assert successful start - assert.NoError(t, err) + startTime := be.GetStartTime() + assert.False(t, startTime.IsZero(), "Expected start time to be set") - // Get Running status status, _, err := be.GetRunningStatus() - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, backend.Running, status, "Expected backend to be running") - // Get capabilities capabilities, err := be.GetCapabilities() - assert.NoError(t, err) - assert.Equal(t, capabilities["capability"], true, "Expected capability to be true") + require.NoError(t, err) + assert.Equal(t, true, capabilities["capability"], "Expected capability to be true") + + version, err := be.Version() + require.NoError(t, err) + assert.Equal(t, "1.3.5", version, "Expected version to match response") data := policies.PolicyData{ ID: "dummy-policy-id", Name: "dummy-policy-name", Data: map[string]any{"key": "value"}, } - // Apply policy - err = be.ApplyPolicy(data, false) - assert.NoError(t, err) + require.NoError(t, be.ApplyPolicy(data, false)) - // Update policy - err = be.ApplyPolicy(data, true) - assert.NoError(t, err) + updatedData := policies.PolicyData{ + ID: data.ID, + Name: "dummy-policy-updated", + Data: map[string]any{"key": "value"}, + PreviousPolicyData: &policies.PolicyData{ + Name: data.Name, + }, + } + require.NoError(t, be.ApplyPolicy(updatedData, true)) - // Assert restart - err = be.FullReset(ctx) - assert.NoError(t, err) + require.NoError(t, be.FullReset(ctx)) - // Verify expectations mockCmd.AssertExpectations(t) } func TestDeviceDiscoveryBackendCompleted(t *testing.T) { - // Create a mock command that simulates a failure mockCmd := &mocks.MockCmd{} mocks.SetupCompletedProcess(mockCmd, 0, nil) - // Save original function and restore after test - originalNewCmdOptions := backend.NewCmdOptions - defer func() { - backend.NewCmdOptions = originalNewCmdOptions - }() - - // Override NewCmdOptions to return our mock - backend.NewCmdOptions = func(_ backend.CmdOptions, _ string, _ ...string) backend.Commander { - return mockCmd - } - assert.True(t, devicediscovery.Register(), "Failed to register DeviceDiscovery backend") + overrideNewCmdOptions(t, mockCmd, nil) + assert.True(t, devicediscovery.Register(), "Failed to register DeviceDiscovery backend") assert.True(t, backend.HaveBackend("device_discovery"), "Failed to get DeviceDiscovery backend") be := backend.GetBackend("device_discovery") - // Configure backend with invalid parameters - err := be.Configure(slog.Default(), nil, map[string]any{ + require.NoError(t, be.Configure(slog.Default(), nil, map[string]any{ "host": "invalid-host", - }, config.BackendCommons{}) - assert.NoError(t, err) + }, config.BackendCommons{})) ctx, cancel := context.WithCancel(context.Background()) - err = be.Start(ctx, cancel) + defer cancel() + err := be.Start(ctx, cancel) assert.Error(t, err) + + mockCmd.AssertExpectations(t) } -func TestDeviceDiscoveryBackendDryRun(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - switch { - case r.URL.Path == "/api/v1/status": - response := StatusResponse{Version: "1.3.5", StartTime: "2023-10-01T12:00:00Z", UpTime: 123.456} - _ = json.NewEncoder(w).Encode(response) - case r.URL.Path == "/api/v1/capabilities": - capabilities := map[string]any{"capability": true} - _ = json.NewEncoder(w).Encode(capabilities) - case strings.HasPrefix(r.URL.Path, "/api/v1/policies"): - _ = json.NewEncoder(w).Encode(map[string]any{"status": "ok"}) - default: - w.WriteHeader(http.StatusNotFound) - } - })) - defer server.Close() +func createExecutable(t *testing.T, name string) { + t.Helper() - serverURL, err := url.Parse(server.URL) - assert.NoError(t, err) + tempDir := t.TempDir() + binaryPath := path.Join(tempDir, name) - logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - repo, err := policies.NewMemRepo() - assert.NoError(t, err) + file, err := os.Create(binaryPath) + require.NoError(t, err) + require.NoError(t, file.Close()) + require.NoError(t, os.Chmod(binaryPath, 0o755)) - mockCmd := &mocks.MockCmd{} - mocks.SetupSuccessfulProcess(mockCmd, 12345) + originalPath := os.Getenv("PATH") + t.Setenv("PATH", tempDir+string(os.PathListSeparator)+originalPath) +} - originalNewCmdOptions := backend.NewCmdOptions - defer func() { - backend.NewCmdOptions = originalNewCmdOptions - }() +func overrideNewCmdOptions(t *testing.T, cmd backend.Commander, assertFn func(options backend.CmdOptions, name string, args []string)) { + t.Helper() + original := backend.NewCmdOptions backend.NewCmdOptions = func(options backend.CmdOptions, name string, args ...string) backend.Commander { - assert.Equal(t, "device-discovery", name, "Expected command name to be device-discovery") - assert.Contains(t, args, "--dry-run") - assert.Contains(t, args, "--dry-run-output-dir") - assert.NotContains(t, args, "--host") - assert.NotContains(t, args, "--port") - assert.False(t, options.Buffered, "Expected buffered to be false") - assert.True(t, options.Streaming, "Expected streaming to be true") - return mockCmd - } - - assert.True(t, devicediscovery.Register()) - be := backend.GetBackend("device_discovery") - - beCommons := config.BackendCommons{ - Diode: struct { - Target string `yaml:"target"` - ClientID string `yaml:"client_id"` - ClientSecret string `yaml:"client_secret"` - AgentName string `yaml:"agent_name"` - DryRun bool `yaml:"dry_run"` - DryRunOutputDir string `yaml:"dry_run_output_dir"` - }{ - DryRun: true, - DryRunOutputDir: "/tmp/dry-run-output", - }, + if assertFn != nil { + assertFn(options, name, args) + } + return cmd } - err = be.Configure(logger, repo, map[string]any{ - "host": serverURL.Hostname(), - "port": serverURL.Port(), - }, beCommons) - assert.NoError(t, err) - - ctx, cancel := context.WithCancel(context.Background()) - err = be.Start(ctx, cancel) - assert.NoError(t, err) - - assert.False(t, be.GetStartTime().IsZero()) - - err = be.RemovePolicy(policies.PolicyData{ID: "1", Name: "policy", Data: map[string]any{"k": "v"}}) - assert.NoError(t, err) - - err = be.Stop(context.WithValue(context.Background(), config.ContextKey("routine"), "test")) - assert.NoError(t, err) - - mockCmd.AssertExpectations(t) + t.Cleanup(func() { + backend.NewCmdOptions = original + }) } diff --git a/agent/backend/networkdiscovery/network_discovery.go b/agent/backend/networkdiscovery/network_discovery.go index 499ecb1..e772fd5 100644 --- a/agent/backend/networkdiscovery/network_discovery.go +++ b/agent/backend/networkdiscovery/network_discovery.go @@ -79,7 +79,9 @@ func (d *networkDiscoveryBackend) Configure(logger *slog.Logger, repo policies.P if d.apiHost, prs = config["host"].(string); !prs { d.apiHost = defaultAPIHost } - if d.apiPort, prs = config["port"].(string); !prs { + if port, prs := config["port"]; prs { + d.apiPort = fmt.Sprintf("%v", port) + } else { d.apiPort = defaultAPIPort } @@ -238,6 +240,13 @@ func (d *networkDiscoveryBackend) Start(ctx context.Context, cancelFunc context. var version string var readinessErr error for backoff := range readinessBackoff { + if status := d.proc.Status(); status.Complete { + err := d.proc.Stop() + if err != nil { + d.logger.Error("proc.Stop error", slog.Any("error", err)) + } + return errors.New("network-discovery process ended unexpectedly, check log") + } version, readinessErr = d.Version() if readinessErr == nil { d.logger.Info("network-discovery readiness ok, got version ", diff --git a/agent/backend/networkdiscovery/network_discovery_test.go b/agent/backend/networkdiscovery/network_discovery_test.go index 8c8b4f3..d64a13b 100644 --- a/agent/backend/networkdiscovery/network_discovery_test.go +++ b/agent/backend/networkdiscovery/network_discovery_test.go @@ -8,10 +8,12 @@ import ( "net/http/httptest" "net/url" "os" + "path" "strings" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/netboxlabs/orb-agent/agent/backend" "github.com/netboxlabs/orb-agent/agent/backend/mocks" @@ -21,461 +23,198 @@ import ( ) type StatusResponse struct { - StartTime string `json:"start_time"` Version string `json:"version"` + StartTime string `json:"start_time"` UpTime float64 `json:"up_time"` } func TestNetworkDiscoveryBackendStart(t *testing.T) { - // Create server server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - if r.URL.Path == "/api/v1/status" { + switch { + case r.URL.Path == "/api/v1/status": response := StatusResponse{ Version: "1.3.4", StartTime: "2023-10-01T12:00:00Z", UpTime: 123.456, } w.WriteHeader(http.StatusOK) - err := json.NewEncoder(w).Encode(response) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - } else if r.URL.Path == "/api/v1/capabilities" { + require.NoError(t, json.NewEncoder(w).Encode(response)) + case r.URL.Path == "/api/v1/capabilities": capabilities := map[string]any{ "capability": true, } w.WriteHeader(http.StatusOK) - err := json.NewEncoder(w).Encode(capabilities) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - } else if strings.Contains(r.URL.Path, "/api/v1/policies") { - switch r.Method { - case http.MethodPost: - w.WriteHeader(http.StatusOK) - response := map[string]any{ - "status": "success", - "message": "Policy applied successfully", - } - err := json.NewEncoder(w).Encode(response) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - case http.MethodDelete: - w.WriteHeader(http.StatusOK) - response := map[string]any{ - "status": "success", - "message": "Policy removed successfully", - } - err := json.NewEncoder(w).Encode(response) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - } - } else { + require.NoError(t, json.NewEncoder(w).Encode(capabilities)) + case strings.Contains(r.URL.Path, "/api/v1/policies"): + w.WriteHeader(http.StatusOK) + require.NoError(t, json.NewEncoder(w).Encode(map[string]any{ + "status": "success", + "message": "Policy operation successful", + })) + default: w.WriteHeader(http.StatusNotFound) } })) defer server.Close() - // Parse server URL serverURL, err := url.Parse(server.URL) - assert.NoError(t, err) + require.NoError(t, err) + + createExecutable(t, "network-discovery") - // Create logger logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - // Create a mock repository repo, err := policies.NewMemRepo() - assert.NoError(t, err) + require.NoError(t, err) - // Create a mock command mockCmd := &mocks.MockCmd{} mocks.SetupSuccessfulProcess(mockCmd, 12345) - // Save original function and restore after test - originalNewCmdOptions := backend.NewCmdOptions - defer func() { - backend.NewCmdOptions = originalNewCmdOptions - }() - - // Override NewCmdOptions to return our mock - backend.NewCmdOptions = func(options backend.CmdOptions, name string, args ...string) backend.Commander { - // Assert that the correct parameters were passed + overrideNewCmdOptions(t, mockCmd, func(options backend.CmdOptions, name string, args []string) { assert.Equal(t, "network-discovery", name, "Expected command name to be network-discovery") - assert.Contains(t, args, "--port", "Expected args to contain port") - assert.Contains(t, args, "--host", "Expected args to contain host") assert.False(t, options.Buffered, "Expected buffered to be false") assert.True(t, options.Streaming, "Expected streaming to be true") - return mockCmd - } + assert.Contains(t, args, "--host", "Expected args to contain host flag") + assert.Contains(t, args, serverURL.Hostname(), "Expected args to contain host value") + assert.Contains(t, args, "--port", "Expected args to contain port flag") + assert.Contains(t, args, serverURL.Port(), "Expected args to contain port value") + assert.Contains(t, args, "--diode-target", "Expected args to contain diode target flag") + assert.Contains(t, args, "network-target", "Expected args to contain diode target") + assert.Contains(t, args, "--diode-client-id", "Expected args to contain diode client id flag") + assert.Contains(t, args, "network-client", "Expected args to contain diode client id") + assert.Contains(t, args, "--diode-client-secret", "Expected args to contain diode client secret flag") + assert.Contains(t, args, "network-secret", "Expected args to contain diode client secret") + assert.Contains(t, args, "--diode-app-name-prefix", "Expected args to contain diode app name prefix flag") + assert.Contains(t, args, "network-agent", "Expected args to contain diode app name prefix") + assert.Contains(t, args, "--log-level", "Expected args to contain log level flag") + assert.Contains(t, args, "debug", "Expected args to contain log level value") + assert.Contains(t, args, "--otel-endpoint", "Expected args to contain otel endpoint flag") + assert.Contains(t, args, "collector:4317", "Expected args to contain otel endpoint value") + }) assert.True(t, networkdiscovery.Register(), "Failed to register NetworkDiscovery backend") - assert.True(t, backend.HaveBackend("network_discovery"), "Failed to get NetworkDiscovery backend") be := backend.GetBackend("network_discovery") assert.Equal(t, backend.Unknown, be.GetInitialState()) - // Configure backend + commons := config.BackendCommons{} + commons.Otlp.Grpc = "collector:4317" + commons.Diode.Target = "default-target" + commons.Diode.ClientID = "default-client" + commons.Diode.ClientSecret = "default-secret" + commons.Diode.AgentName = "default-agent" + commons.Diode.DryRunOutputDir = "/tmp/default" + err = be.Configure(logger, repo, map[string]any{ - "host": serverURL.Hostname(), - "port": serverURL.Port(), - }, config.BackendCommons{}) - assert.NoError(t, err) + "host": serverURL.Hostname(), + "port": serverURL.Port(), + "target": "network-target", + "client_id": "network-client", + "client_secret": "network-secret", + "agent_name": "network-agent", + "log_level": "debug", + "dry_run": false, + "dry_run_output_dir": "/tmp/network", + }, commons) + require.NoError(t, err) - // Start the backend ctx, cancel := context.WithCancel(context.Background()) - err = be.Start(ctx, cancel) + defer cancel() + + require.NoError(t, be.Start(ctx, cancel)) - // Assert successful start - assert.NoError(t, err) + startTime := be.GetStartTime() + assert.False(t, startTime.IsZero(), "Expected start time to be set") - // Get Running status status, _, err := be.GetRunningStatus() - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, backend.Running, status, "Expected backend to be running") - // Get capabilities capabilities, err := be.GetCapabilities() - assert.NoError(t, err) - assert.Equal(t, capabilities["capability"], true, "Expected capability to be true") + require.NoError(t, err) + assert.Equal(t, true, capabilities["capability"], "Expected capability to be true") + + version, err := be.Version() + require.NoError(t, err) + assert.Equal(t, "1.3.4", version, "Expected version to match response") data := policies.PolicyData{ ID: "dummy-policy-id", Name: "dummy-policy-name", Data: map[string]any{"key": "value"}, } - // Apply policy - err = be.ApplyPolicy(data, false) - assert.NoError(t, err) + require.NoError(t, be.ApplyPolicy(data, false)) - // Update policy - err = be.ApplyPolicy(data, true) - assert.NoError(t, err) + updatedData := policies.PolicyData{ + ID: data.ID, + Name: "dummy-policy-updated", + Data: map[string]any{"key": "value"}, + PreviousPolicyData: &policies.PolicyData{ + Name: data.Name, + }, + } + require.NoError(t, be.ApplyPolicy(updatedData, true)) - // Assert restart - err = be.FullReset(ctx) - assert.NoError(t, err) + require.NoError(t, be.FullReset(ctx)) - // Verify expectations mockCmd.AssertExpectations(t) } func TestNetworkDiscoveryBackendCompleted(t *testing.T) { - // Create a mock command that simulates a failure mockCmd := &mocks.MockCmd{} mocks.SetupCompletedProcess(mockCmd, 0, nil) - // Save original function and restore after test - originalNewCmdOptions := backend.NewCmdOptions - defer func() { - backend.NewCmdOptions = originalNewCmdOptions - }() - - // Override NewCmdOptions to return our mock - backend.NewCmdOptions = func(_ backend.CmdOptions, _ string, _ ...string) backend.Commander { - return mockCmd - } - assert.True(t, networkdiscovery.Register(), "Failed to register NetworkDiscovery backend") + overrideNewCmdOptions(t, mockCmd, nil) + assert.True(t, networkdiscovery.Register(), "Failed to register NetworkDiscovery backend") assert.True(t, backend.HaveBackend("network_discovery"), "Failed to get NetworkDiscovery backend") be := backend.GetBackend("network_discovery") - // Configure backend with invalid parameters - err := be.Configure(slog.Default(), nil, map[string]any{ + require.NoError(t, be.Configure(slog.Default(), nil, map[string]any{ "host": "invalid-host", - }, config.BackendCommons{}) - assert.NoError(t, err) + }, config.BackendCommons{})) ctx, cancel := context.WithCancel(context.Background()) - err = be.Start(ctx, cancel) + defer cancel() + err := be.Start(ctx, cancel) assert.Error(t, err) -} - -func TestNetworkDiscoveryBackendDryRun(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - switch { - case r.URL.Path == "/api/v1/status": - response := StatusResponse{Version: "1.3.5", StartTime: "2023-10-01T12:00:00Z", UpTime: 123.456} - _ = json.NewEncoder(w).Encode(response) - case r.URL.Path == "/api/v1/capabilities": - capabilities := map[string]any{"capability": true} - _ = json.NewEncoder(w).Encode(capabilities) - case strings.HasPrefix(r.URL.Path, "/api/v1/policies"): - _ = json.NewEncoder(w).Encode(map[string]any{"status": "ok"}) - default: - w.WriteHeader(http.StatusNotFound) - } - })) - defer server.Close() - - serverURL, err := url.Parse(server.URL) - assert.NoError(t, err) - - logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - repo, err := policies.NewMemRepo() - assert.NoError(t, err) - - mockCmd := &mocks.MockCmd{} - mocks.SetupSuccessfulProcess(mockCmd, 12345) - - originalNewCmdOptions := backend.NewCmdOptions - defer func() { - backend.NewCmdOptions = originalNewCmdOptions - }() - - backend.NewCmdOptions = func(options backend.CmdOptions, name string, args ...string) backend.Commander { - assert.Equal(t, "network-discovery", name, "Expected command name to be network-discovery") - assert.Contains(t, args, "--dry-run") - assert.Contains(t, args, "--dry-run-output-dir") - assert.NotContains(t, args, "--host") - assert.NotContains(t, args, "--port") - assert.False(t, options.Buffered, "Expected buffered to be false") - assert.True(t, options.Streaming, "Expected streaming to be true") - return mockCmd - } - - assert.True(t, networkdiscovery.Register()) - be := backend.GetBackend("network_discovery") - - beCommons := config.BackendCommons{ - Diode: struct { - Target string `yaml:"target"` - ClientID string `yaml:"client_id"` - ClientSecret string `yaml:"client_secret"` - AgentName string `yaml:"agent_name"` - DryRun bool `yaml:"dry_run"` - DryRunOutputDir string `yaml:"dry_run_output_dir"` - }{ - DryRun: true, - DryRunOutputDir: "/tmp/dry-run-output", - }, - } - - err = be.Configure(logger, repo, map[string]any{ - "host": serverURL.Hostname(), - "port": serverURL.Port(), - }, beCommons) - assert.NoError(t, err) - - ctx, cancel := context.WithCancel(context.Background()) - err = be.Start(ctx, cancel) - assert.NoError(t, err) - - assert.False(t, be.GetStartTime().IsZero()) - - err = be.RemovePolicy(policies.PolicyData{ID: "1", Name: "policy", Data: map[string]any{"k": "v"}}) - assert.NoError(t, err) - - err = be.Stop(context.WithValue(context.Background(), config.ContextKey("routine"), "test")) - assert.NoError(t, err) mockCmd.AssertExpectations(t) } -func TestNetworkDiscoveryLogLevel(t *testing.T) { - // Create a test server for all log level tests - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - if r.URL.Path == "/api/v1/status" { - response := StatusResponse{ - Version: "1.3.4", - StartTime: "2023-10-01T12:00:00Z", - UpTime: 123.456, - } - _ = json.NewEncoder(w).Encode(response) - } - })) - defer server.Close() - - serverURL, err := url.Parse(server.URL) - assert.NoError(t, err) - - // Test default log level - t.Run("LogLevelNotSet", func(t *testing.T) { - logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - repo, err := policies.NewMemRepo() - assert.NoError(t, err) - - assert.True(t, networkdiscovery.Register()) - be := backend.GetBackend("network_discovery") - - // Configure without log_level - should use default - err = be.Configure(logger, repo, map[string]any{ - "host": serverURL.Hostname(), - "port": serverURL.Port(), - }, config.BackendCommons{}) - assert.NoError(t, err) - - mockCmd := &mocks.MockCmd{} - mocks.SetupSuccessfulProcess(mockCmd, 12345) - - originalNewCmdOptions := backend.NewCmdOptions - defer func() { - backend.NewCmdOptions = originalNewCmdOptions - }() - - // Verify that log level is not set in command args - backend.NewCmdOptions = func(_ backend.CmdOptions, _ string, args ...string) backend.Commander { - assert.NotContains(t, args, "--log-level") - return mockCmd - } - - ctx, cancel := context.WithCancel(context.Background()) - err = be.Start(ctx, cancel) - assert.NoError(t, err) - - // Stop the backend - _ = be.Stop(ctx) - mockCmd.AssertExpectations(t) - }) +func createExecutable(t *testing.T, name string) { + t.Helper() - // Test custom log level - t.Run("CustomLogLevel", func(t *testing.T) { - logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - repo, err := policies.NewMemRepo() - assert.NoError(t, err) - - assert.True(t, networkdiscovery.Register()) - be := backend.GetBackend("network_discovery") - - customLogLevel := "debug" - // Configure with custom log_level - err = be.Configure(logger, repo, map[string]any{ - "host": serverURL.Hostname(), - "port": serverURL.Port(), - "log_level": customLogLevel, - }, config.BackendCommons{}) - assert.NoError(t, err) - - mockCmd := &mocks.MockCmd{} - mocks.SetupSuccessfulProcess(mockCmd, 12345) - - originalNewCmdOptions := backend.NewCmdOptions - defer func() { - backend.NewCmdOptions = originalNewCmdOptions - }() - - // Verify that custom log level is used in command args - backend.NewCmdOptions = func(_ backend.CmdOptions, _ string, args ...string) backend.Commander { - assert.Contains(t, args, "--log-level") - // Find the index of --log-level and check the next argument - for i, arg := range args { - if arg == "--log-level" && i+1 < len(args) { - assert.Equal(t, customLogLevel, args[i+1]) - break - } - } - return mockCmd - } + tempDir := t.TempDir() + binaryPath := path.Join(tempDir, name) - ctx, cancel := context.WithCancel(context.Background()) - err = be.Start(ctx, cancel) - assert.NoError(t, err) + file, err := os.Create(binaryPath) + require.NoError(t, err) + require.NoError(t, file.Close()) + require.NoError(t, os.Chmod(binaryPath, 0o755)) - // Stop the backend - _ = be.Stop(ctx) - mockCmd.AssertExpectations(t) - }) + originalPath := os.Getenv("PATH") + t.Setenv("PATH", tempDir+string(os.PathListSeparator)+originalPath) +} - // Test dry run mode includes log level - t.Run("DryRunWithLogLevel", func(t *testing.T) { - // Create a test server for dry run tests (even though dry run doesn't use HTTP, the backend still tries readiness checks) - dryRunServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - switch { - case r.URL.Path == "/api/v1/status": - response := StatusResponse{Version: "1.3.5", StartTime: "2023-10-01T12:00:00Z", UpTime: 123.456} - _ = json.NewEncoder(w).Encode(response) - case r.URL.Path == "/api/v1/capabilities": - capabilities := map[string]any{"capability": true} - _ = json.NewEncoder(w).Encode(capabilities) - case strings.HasPrefix(r.URL.Path, "/api/v1/policies"): - _ = json.NewEncoder(w).Encode(map[string]any{"status": "ok"}) - default: - w.WriteHeader(http.StatusNotFound) - } - })) - defer dryRunServer.Close() - - dryRunServerURL, err := url.Parse(dryRunServer.URL) - assert.NoError(t, err) - - logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - repo, err := policies.NewMemRepo() - assert.NoError(t, err) - - assert.True(t, networkdiscovery.Register()) - be := backend.GetBackend("network_discovery") - - beCommons := config.BackendCommons{ - Diode: struct { - Target string `yaml:"target"` - ClientID string `yaml:"client_id"` - ClientSecret string `yaml:"client_secret"` - AgentName string `yaml:"agent_name"` - DryRun bool `yaml:"dry_run"` - DryRunOutputDir string `yaml:"dry_run_output_dir"` - }{ - DryRun: true, - DryRunOutputDir: "/tmp/dry-run-output", - }, - } +func overrideNewCmdOptions(t *testing.T, cmd backend.Commander, assertFn func(options backend.CmdOptions, name string, args []string)) { + t.Helper() - customLogLevel := "debug" - err = be.Configure(logger, repo, map[string]any{ - "host": dryRunServerURL.Hostname(), - "port": dryRunServerURL.Port(), - "log_level": customLogLevel, - }, beCommons) - assert.NoError(t, err) - - mockCmd := &mocks.MockCmd{} - mocks.SetupSuccessfulProcess(mockCmd, 12345) - - originalNewCmdOptions := backend.NewCmdOptions - defer func() { - backend.NewCmdOptions = originalNewCmdOptions - }() - - // Verify that log level IS included in dry run mode - backend.NewCmdOptions = func(_ backend.CmdOptions, _ string, args ...string) backend.Commander { - assert.Contains(t, args, "--log-level") - // Find the index of --log-level and check the next argument - for i, arg := range args { - if arg == "--log-level" && i+1 < len(args) { - assert.Equal(t, customLogLevel, args[i+1]) - break - } - } - // Verify dry run specific args are also present - assert.Contains(t, args, "--dry-run") - assert.Contains(t, args, "--dry-run-output-dir") - // Verify non-dry-run args are NOT present - assert.NotContains(t, args, "--host") - assert.NotContains(t, args, "--port") - assert.NotContains(t, args, "--diode-target") - return mockCmd + original := backend.NewCmdOptions + backend.NewCmdOptions = func(options backend.CmdOptions, name string, args ...string) backend.Commander { + if assertFn != nil { + assertFn(options, name, args) } + return cmd + } - ctx, cancel := context.WithCancel(context.Background()) - err = be.Start(ctx, cancel) - assert.NoError(t, err) - - // Stop the backend - err = be.Stop(context.WithValue(context.Background(), config.ContextKey("routine"), "test")) - assert.NoError(t, err) - mockCmd.AssertExpectations(t) + t.Cleanup(func() { + backend.NewCmdOptions = original }) } diff --git a/agent/backend/opentelemetryinfinity/opentelemetry_infinity.go b/agent/backend/opentelemetryinfinity/opentelemetry_infinity.go index 4d72b4b..ce9c7d9 100644 --- a/agent/backend/opentelemetryinfinity/opentelemetry_infinity.go +++ b/agent/backend/opentelemetryinfinity/opentelemetry_infinity.go @@ -73,10 +73,11 @@ func (o *openTelemetryBackend) Configure(logger *slog.Logger, repo policies.Poli if o.apiHost, prs = config["host"].(string); !prs { o.apiHost = defaultAPIHost } - if o.apiPort, prs = config["port"].(string); !prs { + if port, prs := config["port"]; prs { + o.apiPort = fmt.Sprintf("%v", port) + } else { o.apiPort = defaultAPIPort } - o.agentLabels = common.Otlp.AgentLabels return nil @@ -168,10 +169,17 @@ func (o *openTelemetryBackend) Start(ctx context.Context, cancelFunc context.Can var version string var readinessErr error for backoff := range readinessBackoff { + if status := o.proc.Status(); status.Complete { + err := o.proc.Stop() + if err != nil { + o.logger.Error("proc.Stop error", slog.Any("error", err)) + } + return errors.New("opentelemetry infinity process ended unexpectedly, check log") + } version, readinessErr = o.Version() if readinessErr == nil { o.logger.Info("opentelemetry infinity readiness ok, got version ", - slog.String("device_discovery_version", version)) + slog.String("opentelemetry_infinity_version", version)) break } backoffDuration := time.Duration(backoff) * time.Second diff --git a/agent/backend/pktvisor/pktvisor.go b/agent/backend/pktvisor/pktvisor.go index 93507e0..87cf959 100644 --- a/agent/backend/pktvisor/pktvisor.go +++ b/agent/backend/pktvisor/pktvisor.go @@ -193,6 +193,13 @@ func (p *pktvisorBackend) Start(ctx context.Context, cancelFunc context.CancelFu var readinessError error for backoff := range readinessBackoff { + if status := p.proc.Status(); status.Complete { + err := p.proc.Stop() + if err != nil { + p.logger.Error("proc.Stop error", slog.Any("error", err)) + } + return errors.New("pktvisor process ended unexpectedly, check log") + } var appMetrics AppInfo url := fmt.Sprintf("%s://%s:%s/api/v1/metrics/app", p.adminAPIProtocol, p.adminAPIHost, p.adminAPIPort) readinessError = backend.CommonRequest("pktvisor", p.proc, p.logger, url, &appMetrics, http.MethodGet, @@ -266,9 +273,7 @@ func (p *pktvisorBackend) Configure(logger *slog.Logger, repo policies.PolicyRep } configSection[key] = value case "port": - if v, ok := value.(string); ok { - p.adminAPIPort = v - } + p.adminAPIPort = fmt.Sprintf("%v", value) configSection[key] = value case "taps": visorConfig["taps"] = value diff --git a/agent/backend/pktvisor/pktvisor_test.go b/agent/backend/pktvisor/pktvisor_test.go index 8c1d288..8546809 100644 --- a/agent/backend/pktvisor/pktvisor_test.go +++ b/agent/backend/pktvisor/pktvisor_test.go @@ -10,6 +10,7 @@ import ( "os" "path" "strings" + "sync/atomic" "testing" "github.com/stretchr/testify/assert" @@ -29,190 +30,217 @@ type StatusResponse struct { } func TestPktvisorBackendStart(t *testing.T) { - // Create server server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - if r.URL.Path == "/api/v1/metrics/app" { + switch { + case r.URL.Path == "/api/v1/metrics/app": var response pktvisor.AppInfo response.App.Version = "1.2.3" response.App.UpTimeMin = 42.5 w.WriteHeader(http.StatusOK) - err := json.NewEncoder(w).Encode(response) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - } else if r.URL.Path == "/api/v1/taps" { - capabilities := map[string]any{ - "iface": "eth0", - } + require.NoError(t, json.NewEncoder(w).Encode(response)) + case r.URL.Path == "/api/v1/taps": w.WriteHeader(http.StatusOK) - err := json.NewEncoder(w).Encode(capabilities) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - } else if strings.Contains(r.URL.Path, "/api/v1/policies") { - switch r.Method { - case http.MethodPost: - w.WriteHeader(http.StatusOK) - response := map[string]any{ - "status": "success", - "message": "Policy applied successfully", - } - err := json.NewEncoder(w).Encode(response) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - case http.MethodDelete: - w.WriteHeader(http.StatusOK) - response := map[string]any{ - "status": "success", - "message": "Policy removed successfully", - } - err := json.NewEncoder(w).Encode(response) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - } - } else { + require.NoError(t, json.NewEncoder(w).Encode(map[string]any{"iface": "eth0"})) + case strings.Contains(r.URL.Path, "/api/v1/policies"): + w.WriteHeader(http.StatusOK) + require.NoError(t, json.NewEncoder(w).Encode(map[string]any{ + "status": "success", + "message": "Policy operation successful", + })) + default: w.WriteHeader(http.StatusNotFound) } })) defer server.Close() - // Create a temporary directory and file for the test - tempDir := t.TempDir() - binaryPath := path.Join(tempDir, "pktvisord") - dummyBinary, err := os.Create(binaryPath) - require.NoError(t, err) - err = dummyBinary.Close() - require.NoError(t, err) - - // Make the binary executable - err = os.Chmod(binaryPath, 0o755) - require.NoError(t, err) - - // Add our temp directory to the PATH - err = os.Setenv("PATH", tempDir+string(os.PathListSeparator)+os.Getenv("PATH")) + serverURL, err := url.Parse(server.URL) require.NoError(t, err) - // Parse server URL - serverURL, err := url.Parse(server.URL) - assert.NoError(t, err) + createExecutable(t, "pktvisord") - // Create logger logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - // Create a mock repository repo, err := policies.NewMemRepo() - assert.NoError(t, err) + require.NoError(t, err) - // Create a mock command mockCmd := &mocks.MockCmd{} mocks.SetupSuccessfulProcess(mockCmd, 12345) - // Save original function and restore after test - originalNewCmdOptions := backend.NewCmdOptions - defer func() { - backend.NewCmdOptions = originalNewCmdOptions - }() - - // Override NewCmdOptions to return our mock - backend.NewCmdOptions = func(options backend.CmdOptions, name string, args ...string) backend.Commander { - // Assert that the correct parameters were passed + overrideNewCmdOptions(t, mockCmd, func(options backend.CmdOptions, name string, args []string) { assert.Equal(t, "pktvisord", name, "Expected command name to be pktvisord") - assert.Contains(t, args, "--admin-api", "Expected args to contain admin-api") assert.False(t, options.Buffered, "Expected buffered to be false") assert.True(t, options.Streaming, "Expected streaming to be true") - return mockCmd - } + assert.Contains(t, args, "--admin-api", "Expected args to contain admin-api flag") + assert.Contains(t, args, "--config", "Expected args to contain config flag") + assert.Contains(t, args, "--otel", "Expected args to contain otel flag") + assert.Contains(t, args, "--otel-host", "Expected args to contain otel host flag") + assert.Contains(t, args, serverURL.Hostname(), "Expected args to contain otel host value") + assert.Contains(t, args, "--otel-port", "Expected args to contain otel port flag") + assert.Contains(t, args, serverURL.Port(), "Expected args to contain otel port value") + }) assert.True(t, pktvisor.Register(), "Failed to register Pktvisor backend") - assert.True(t, backend.HaveBackend("pktvisor"), "Failed to get Pktvisor backend") be := backend.GetBackend("pktvisor") assert.Equal(t, backend.Unknown, be.GetInitialState()) - // Configure backend + commons := config.BackendCommons{} + commons.Otlp.HTTP = server.URL + commons.Otlp.AgentLabels = map[string]string{"env": "test"} + err = be.Configure(logger, repo, map[string]any{ "host": serverURL.Hostname(), "port": serverURL.Port(), - }, config.BackendCommons{}) - assert.NoError(t, err) + }, commons) + require.NoError(t, err) + + baseCtx := context.WithValue(context.Background(), config.ContextKey("agent_id"), "test-agent") + ctx, cancel := context.WithCancel(baseCtx) + defer cancel() - // Start the backend - ctx, cancel := context.WithCancel(context.Background()) err = be.Start(ctx, cancel) + require.NoError(t, err) - // Assert successful start - assert.NoError(t, err) + startTime := be.GetStartTime() + assert.False(t, startTime.IsZero(), "Expected start time to be set") - // Get Running status status, _, err := be.GetRunningStatus() - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, backend.Running, status, "Expected backend to be running") - // Get capabilities capabilities, err := be.GetCapabilities() - assert.NoError(t, err) - assert.Equal(t, map[string]any{"iface": "eth0"}, capabilities["taps"], "Expected taps") + require.NoError(t, err) + assert.Equal(t, map[string]any{"iface": "eth0"}, capabilities["taps"], "Expected taps to match response") + + version, err := be.Version() + require.NoError(t, err) + assert.Equal(t, "1.2.3", version, "Expected version to match response") data := policies.PolicyData{ ID: "dummy-policy-id", Name: "dummy-policy-name", Data: map[string]any{"key": "value"}, } - // Apply policy - err = be.ApplyPolicy(data, false) - assert.NoError(t, err) + require.NoError(t, be.ApplyPolicy(data, false)) - // Update policy - err = be.ApplyPolicy(data, true) - assert.NoError(t, err) + updatedData := policies.PolicyData{ + ID: data.ID, + Name: "dummy-policy-updated", + Data: map[string]any{"key": "value"}, + PreviousPolicyData: &policies.PolicyData{ + Name: data.Name, + }, + } + require.NoError(t, be.ApplyPolicy(updatedData, true)) - // Assert restart - err = be.FullReset(ctx) - assert.NoError(t, err) + require.NoError(t, be.FullReset(ctx)) - // Verify expectations mockCmd.AssertExpectations(t) } -func TestPktvisorBackendCompleted(t *testing.T) { - // Create a mock command that simulates a failure +func TestPktvisorGetRunningStatusAPIFailure(t *testing.T) { + var metricsCalls atomic.Int32 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + switch { + case r.URL.Path == "/api/v1/metrics/app": + if metricsCalls.Add(1) == 1 { + var response pktvisor.AppInfo + response.App.Version = "1.2.3" + response.App.UpTimeMin = 1.5 + w.WriteHeader(http.StatusOK) + require.NoError(t, json.NewEncoder(w).Encode(response)) + return + } + w.WriteHeader(http.StatusServiceUnavailable) + _, err := w.Write([]byte(`{"error":"unavailable"}`)) + require.NoError(t, err) + case strings.Contains(r.URL.Path, "/api/v1/policies"): + w.WriteHeader(http.StatusOK) + require.NoError(t, json.NewEncoder(w).Encode(map[string]any{ + "status": "success", + "message": "Policy operation successful", + })) + default: + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte(`{}`)) + require.NoError(t, err) + } + })) + defer server.Close() + + serverURL, err := url.Parse(server.URL) + require.NoError(t, err) + + createExecutable(t, "pktvisord") + + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + + repo, err := policies.NewMemRepo() + require.NoError(t, err) + mockCmd := &mocks.MockCmd{} - mocks.SetupCompletedProcess(mockCmd, 0, nil) - // Save original function and restore after test - originalNewCmdOptions := backend.NewCmdOptions - defer func() { - backend.NewCmdOptions = originalNewCmdOptions - }() - - // Override NewCmdOptions to return our mock - backend.NewCmdOptions = func(_ backend.CmdOptions, _ string, _ ...string) backend.Commander { - return mockCmd - } + mocks.SetupSuccessfulProcess(mockCmd, 54321) - assert.True(t, pktvisor.Register(), "Failed to register Pktvisor backend") + overrideNewCmdOptions(t, mockCmd, nil) - assert.True(t, backend.HaveBackend("pktvisor"), "Failed to get Pktvisor backend") + assert.True(t, pktvisor.Register(), "Failed to register Pktvisor backend") be := backend.GetBackend("pktvisor") - // Configure backend with invalid parameters - err := be.Configure(slog.Default(), nil, map[string]any{ - "host": "invalid-host", + err = be.Configure(logger, repo, map[string]any{ + "host": serverURL.Hostname(), + "port": serverURL.Port(), }, config.BackendCommons{}) - assert.NoError(t, err) + require.NoError(t, err) ctx, cancel := context.WithCancel(context.Background()) - err = be.Start(ctx, cancel) + defer cancel() + require.NoError(t, be.Start(ctx, cancel)) + + status, message, err := be.GetRunningStatus() + assert.Equal(t, backend.BackendError, status, "Expected backend to report API failure") + assert.Equal(t, "process running, REST API unavailable", message) assert.Error(t, err) + require.NoError(t, be.Stop(ctx)) + + mockCmd.AssertExpectations(t) +} + +func createExecutable(t *testing.T, name string) { + t.Helper() + + tempDir := t.TempDir() + binaryPath := path.Join(tempDir, name) + + file, err := os.Create(binaryPath) + require.NoError(t, err) + require.NoError(t, file.Close()) + require.NoError(t, os.Chmod(binaryPath, 0o755)) + + originalPath := os.Getenv("PATH") + t.Setenv("PATH", tempDir+string(os.PathListSeparator)+originalPath) +} + +func overrideNewCmdOptions(t *testing.T, cmd backend.Commander, assertFn func(options backend.CmdOptions, name string, args []string)) { + t.Helper() + + original := backend.NewCmdOptions + backend.NewCmdOptions = func(options backend.CmdOptions, name string, args ...string) backend.Commander { + if assertFn != nil { + assertFn(options, name, args) + } + return cmd + } + + t.Cleanup(func() { + backend.NewCmdOptions = original + }) } diff --git a/agent/backend/snmpdiscovery/snmp_discovery.go b/agent/backend/snmpdiscovery/snmp_discovery.go index 8a66d83..33de73e 100644 --- a/agent/backend/snmpdiscovery/snmp_discovery.go +++ b/agent/backend/snmpdiscovery/snmp_discovery.go @@ -79,7 +79,9 @@ func (d *snmpDiscoveryBackend) Configure(logger *slog.Logger, repo policies.Poli if d.apiHost, prs = config["host"].(string); !prs { d.apiHost = defaultAPIHost } - if d.apiPort, prs = config["port"].(string); !prs { + if port, prs := config["port"]; prs { + d.apiPort = fmt.Sprintf("%v", port) + } else { d.apiPort = defaultAPIPort } @@ -232,6 +234,13 @@ func (d *snmpDiscoveryBackend) Start(ctx context.Context, cancelFunc context.Can var version string var readinessErr error for backoff := 1; backoff <= readinessBackoff; backoff++ { + if status := d.proc.Status(); status.Complete { + err := d.proc.Stop() + if err != nil { + d.logger.Error("proc.Stop error", slog.Any("error", err)) + } + return errors.New("snmp-discovery process ended unexpectedly, check log") + } version, readinessErr = d.Version() if readinessErr == nil { d.logger.Info("snmp-discovery readiness ok, got version ", diff --git a/agent/backend/snmpdiscovery/snmp_discovery_test.go b/agent/backend/snmpdiscovery/snmp_discovery_test.go index 82f917e..b0b3762 100644 --- a/agent/backend/snmpdiscovery/snmp_discovery_test.go +++ b/agent/backend/snmpdiscovery/snmp_discovery_test.go @@ -8,10 +8,12 @@ import ( "net/http/httptest" "net/url" "os" + "path" "strings" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/netboxlabs/orb-agent/agent/backend" "github.com/netboxlabs/orb-agent/agent/backend/mocks" @@ -21,461 +23,198 @@ import ( ) type StatusResponse struct { - StartTime string `json:"start_time"` Version string `json:"version"` + StartTime string `json:"start_time"` UpTime float64 `json:"up_time"` } -func TestSNMPDiscoveryBackendStart(t *testing.T) { - // Create server +func TestSnmpDiscoveryBackendStart(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - if r.URL.Path == "/api/v1/status" { + switch { + case r.URL.Path == "/api/v1/status": response := StatusResponse{ - Version: "1.3.4", + Version: "1.3.6", StartTime: "2023-10-01T12:00:00Z", UpTime: 123.456, } w.WriteHeader(http.StatusOK) - err := json.NewEncoder(w).Encode(response) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - } else if r.URL.Path == "/api/v1/capabilities" { + require.NoError(t, json.NewEncoder(w).Encode(response)) + case r.URL.Path == "/api/v1/capabilities": capabilities := map[string]any{ "capability": true, } w.WriteHeader(http.StatusOK) - err := json.NewEncoder(w).Encode(capabilities) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - } else if strings.Contains(r.URL.Path, "/api/v1/policies") { - switch r.Method { - case http.MethodPost: - w.WriteHeader(http.StatusOK) - response := map[string]any{ - "status": "success", - "message": "Policy applied successfully", - } - err := json.NewEncoder(w).Encode(response) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - case http.MethodDelete: - w.WriteHeader(http.StatusOK) - response := map[string]any{ - "status": "success", - "message": "Policy removed successfully", - } - err := json.NewEncoder(w).Encode(response) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - } - } else { + require.NoError(t, json.NewEncoder(w).Encode(capabilities)) + case strings.Contains(r.URL.Path, "/api/v1/policies"): + w.WriteHeader(http.StatusOK) + require.NoError(t, json.NewEncoder(w).Encode(map[string]any{ + "status": "success", + "message": "Policy operation successful", + })) + default: w.WriteHeader(http.StatusNotFound) } })) defer server.Close() - // Parse server URL serverURL, err := url.Parse(server.URL) - assert.NoError(t, err) + require.NoError(t, err) + + createExecutable(t, "snmp-discovery") - // Create logger logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - // Create a mock repository repo, err := policies.NewMemRepo() - assert.NoError(t, err) + require.NoError(t, err) - // Create a mock command mockCmd := &mocks.MockCmd{} mocks.SetupSuccessfulProcess(mockCmd, 12345) - // Save original function and restore after test - originalNewCmdOptions := backend.NewCmdOptions - defer func() { - backend.NewCmdOptions = originalNewCmdOptions - }() - - // Override NewCmdOptions to return our mock - backend.NewCmdOptions = func(options backend.CmdOptions, name string, args ...string) backend.Commander { - // Assert that the correct parameters were passed + overrideNewCmdOptions(t, mockCmd, func(options backend.CmdOptions, name string, args []string) { assert.Equal(t, "snmp-discovery", name, "Expected command name to be snmp-discovery") - assert.Contains(t, args, "--port", "Expected args to contain port") - assert.Contains(t, args, "--host", "Expected args to contain host") assert.False(t, options.Buffered, "Expected buffered to be false") assert.True(t, options.Streaming, "Expected streaming to be true") - return mockCmd - } - - assert.True(t, snmpdiscovery.Register(), "Failed to register SNMP Discovery backend") + assert.Contains(t, args, "--host", "Expected args to contain host flag") + assert.Contains(t, args, serverURL.Hostname(), "Expected args to contain host value") + assert.Contains(t, args, "--port", "Expected args to contain port flag") + assert.Contains(t, args, serverURL.Port(), "Expected args to contain port value") + assert.Contains(t, args, "--diode-target", "Expected args to contain diode target flag") + assert.Contains(t, args, "snmp-target", "Expected args to contain diode target") + assert.Contains(t, args, "--diode-client-id", "Expected args to contain diode client id flag") + assert.Contains(t, args, "snmp-client", "Expected args to contain diode client id") + assert.Contains(t, args, "--diode-client-secret", "Expected args to contain diode client secret flag") + assert.Contains(t, args, "snmp-secret", "Expected args to contain diode client secret") + assert.Contains(t, args, "--diode-app-name-prefix", "Expected args to contain diode app name prefix flag") + assert.Contains(t, args, "snmp-agent", "Expected args to contain diode app name prefix") + assert.Contains(t, args, "--log-level", "Expected args to contain log level flag") + assert.Contains(t, args, "debug", "Expected args to contain log level value") + assert.Contains(t, args, "--otel-endpoint", "Expected args to contain otel endpoint flag") + assert.Contains(t, args, "collector:4317", "Expected args to contain otel endpoint value") + }) - assert.True(t, backend.HaveBackend("snmp_discovery"), "Failed to get SNMP Discovery backend") + assert.True(t, snmpdiscovery.Register(), "Failed to register SnmpDiscovery backend") + assert.True(t, backend.HaveBackend("snmp_discovery"), "Failed to get SnmpDiscovery backend") be := backend.GetBackend("snmp_discovery") assert.Equal(t, backend.Unknown, be.GetInitialState()) - // Configure backend + commons := config.BackendCommons{} + commons.Otlp.Grpc = "collector:4317" + commons.Diode.Target = "default-target" + commons.Diode.ClientID = "default-client" + commons.Diode.ClientSecret = "default-secret" + commons.Diode.AgentName = "default-agent" + commons.Diode.DryRunOutputDir = "/tmp/default" + err = be.Configure(logger, repo, map[string]any{ - "host": serverURL.Hostname(), - "port": serverURL.Port(), - }, config.BackendCommons{}) - assert.NoError(t, err) + "host": serverURL.Hostname(), + "port": serverURL.Port(), + "target": "snmp-target", + "client_id": "snmp-client", + "client_secret": "snmp-secret", + "agent_name": "snmp-agent", + "log_level": "debug", + "dry_run": false, + "dry_run_output_dir": "/tmp/snmp", + }, commons) + require.NoError(t, err) - // Start the backend ctx, cancel := context.WithCancel(context.Background()) - err = be.Start(ctx, cancel) + defer cancel() - // Assert successful start - assert.NoError(t, err) + require.NoError(t, be.Start(ctx, cancel)) + + startTime := be.GetStartTime() + assert.False(t, startTime.IsZero(), "Expected start time to be set") - // Get Running status status, _, err := be.GetRunningStatus() - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, backend.Running, status, "Expected backend to be running") - // Get capabilities capabilities, err := be.GetCapabilities() - assert.NoError(t, err) - assert.Equal(t, capabilities["capability"], true, "Expected capability to be true") + require.NoError(t, err) + assert.Equal(t, true, capabilities["capability"], "Expected capability to be true") + + version, err := be.Version() + require.NoError(t, err) + assert.Equal(t, "1.3.6", version, "Expected version to match response") data := policies.PolicyData{ ID: "dummy-policy-id", Name: "dummy-policy-name", Data: map[string]any{"key": "value"}, } - // Apply policy - err = be.ApplyPolicy(data, false) - assert.NoError(t, err) + require.NoError(t, be.ApplyPolicy(data, false)) - // Update policy - err = be.ApplyPolicy(data, true) - assert.NoError(t, err) + updatedData := policies.PolicyData{ + ID: data.ID, + Name: "dummy-policy-updated", + Data: map[string]any{"key": "value"}, + PreviousPolicyData: &policies.PolicyData{ + Name: data.Name, + }, + } + require.NoError(t, be.ApplyPolicy(updatedData, true)) - // Assert restart - err = be.FullReset(ctx) - assert.NoError(t, err) + require.NoError(t, be.FullReset(ctx)) - // Verify expectations mockCmd.AssertExpectations(t) } -func TestSNMPDiscoveryBackendCompleted(t *testing.T) { - // Create a mock command that simulates a failure +func TestSnmpDiscoveryBackendCompleted(t *testing.T) { mockCmd := &mocks.MockCmd{} mocks.SetupCompletedProcess(mockCmd, 0, nil) - // Save original function and restore after test - originalNewCmdOptions := backend.NewCmdOptions - defer func() { - backend.NewCmdOptions = originalNewCmdOptions - }() - - // Override NewCmdOptions to return our mock - backend.NewCmdOptions = func(_ backend.CmdOptions, _ string, _ ...string) backend.Commander { - return mockCmd - } - assert.True(t, snmpdiscovery.Register(), "Failed to register SNMP Discovery backend") + overrideNewCmdOptions(t, mockCmd, nil) - assert.True(t, backend.HaveBackend("snmp_discovery"), "Failed to get SNMP Discovery backend") + assert.True(t, snmpdiscovery.Register(), "Failed to register SnmpDiscovery backend") + assert.True(t, backend.HaveBackend("snmp_discovery"), "Failed to get SnmpDiscovery backend") be := backend.GetBackend("snmp_discovery") - // Configure backend with invalid parameters - err := be.Configure(slog.Default(), nil, map[string]any{ + require.NoError(t, be.Configure(slog.Default(), nil, map[string]any{ "host": "invalid-host", - }, config.BackendCommons{}) - assert.NoError(t, err) + }, config.BackendCommons{})) ctx, cancel := context.WithCancel(context.Background()) - err = be.Start(ctx, cancel) + defer cancel() + err := be.Start(ctx, cancel) assert.Error(t, err) -} - -func TestNetworkDiscoveryBackendDryRun(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - switch { - case r.URL.Path == "/api/v1/status": - response := StatusResponse{Version: "1.3.5", StartTime: "2023-10-01T12:00:00Z", UpTime: 123.456} - _ = json.NewEncoder(w).Encode(response) - case r.URL.Path == "/api/v1/capabilities": - capabilities := map[string]any{"capability": true} - _ = json.NewEncoder(w).Encode(capabilities) - case strings.HasPrefix(r.URL.Path, "/api/v1/policies"): - _ = json.NewEncoder(w).Encode(map[string]any{"status": "ok"}) - default: - w.WriteHeader(http.StatusNotFound) - } - })) - defer server.Close() - - serverURL, err := url.Parse(server.URL) - assert.NoError(t, err) - - logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - repo, err := policies.NewMemRepo() - assert.NoError(t, err) - - mockCmd := &mocks.MockCmd{} - mocks.SetupSuccessfulProcess(mockCmd, 12345) - - originalNewCmdOptions := backend.NewCmdOptions - defer func() { - backend.NewCmdOptions = originalNewCmdOptions - }() - - backend.NewCmdOptions = func(options backend.CmdOptions, name string, args ...string) backend.Commander { - assert.Equal(t, "snmp-discovery", name, "Expected command name to be snmp-discovery") - assert.Contains(t, args, "--dry-run") - assert.Contains(t, args, "--dry-run-output-dir") - assert.NotContains(t, args, "--host") - assert.NotContains(t, args, "--port") - assert.False(t, options.Buffered, "Expected buffered to be false") - assert.True(t, options.Streaming, "Expected streaming to be true") - return mockCmd - } - - assert.True(t, snmpdiscovery.Register()) - be := backend.GetBackend("snmp_discovery") - - beCommons := config.BackendCommons{ - Diode: struct { - Target string `yaml:"target"` - ClientID string `yaml:"client_id"` - ClientSecret string `yaml:"client_secret"` - AgentName string `yaml:"agent_name"` - DryRun bool `yaml:"dry_run"` - DryRunOutputDir string `yaml:"dry_run_output_dir"` - }{ - DryRun: true, - DryRunOutputDir: "/tmp/dry-run-output", - }, - } - - err = be.Configure(logger, repo, map[string]any{ - "host": serverURL.Hostname(), - "port": serverURL.Port(), - }, beCommons) - assert.NoError(t, err) - - ctx, cancel := context.WithCancel(context.Background()) - err = be.Start(ctx, cancel) - assert.NoError(t, err) - - assert.False(t, be.GetStartTime().IsZero()) - - err = be.RemovePolicy(policies.PolicyData{ID: "1", Name: "policy", Data: map[string]any{"k": "v"}}) - assert.NoError(t, err) - - err = be.Stop(context.WithValue(context.Background(), config.ContextKey("routine"), "test")) - assert.NoError(t, err) mockCmd.AssertExpectations(t) } -func TestSNMPDiscoveryLogLevel(t *testing.T) { - // Create a test server for all log level tests - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - if r.URL.Path == "/api/v1/status" { - response := StatusResponse{ - Version: "1.3.4", - StartTime: "2023-10-01T12:00:00Z", - UpTime: 123.456, - } - _ = json.NewEncoder(w).Encode(response) - } - })) - defer server.Close() +func createExecutable(t *testing.T, name string) { + t.Helper() - serverURL, err := url.Parse(server.URL) - assert.NoError(t, err) - - // Test default log level - t.Run("LogLevelNotSet", func(t *testing.T) { - logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - repo, err := policies.NewMemRepo() - assert.NoError(t, err) - - assert.True(t, snmpdiscovery.Register()) - be := backend.GetBackend("snmp_discovery") - - // Configure without log_level - should use default - err = be.Configure(logger, repo, map[string]any{ - "host": serverURL.Hostname(), - "port": serverURL.Port(), - }, config.BackendCommons{}) - assert.NoError(t, err) - - mockCmd := &mocks.MockCmd{} - mocks.SetupSuccessfulProcess(mockCmd, 12345) - - originalNewCmdOptions := backend.NewCmdOptions - defer func() { - backend.NewCmdOptions = originalNewCmdOptions - }() - - // Verify that log level is not set in command args - backend.NewCmdOptions = func(_ backend.CmdOptions, _ string, args ...string) backend.Commander { - assert.NotContains(t, args, "--log-level") - return mockCmd - } - - ctx, cancel := context.WithCancel(context.Background()) - err = be.Start(ctx, cancel) - assert.NoError(t, err) + tempDir := t.TempDir() + binaryPath := path.Join(tempDir, name) - // Stop the backend - _ = be.Stop(ctx) - mockCmd.AssertExpectations(t) - }) + file, err := os.Create(binaryPath) + require.NoError(t, err) + require.NoError(t, file.Close()) + require.NoError(t, os.Chmod(binaryPath, 0o755)) - // Test custom log level - t.Run("CustomLogLevel", func(t *testing.T) { - logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - repo, err := policies.NewMemRepo() - assert.NoError(t, err) - - assert.True(t, snmpdiscovery.Register()) - be := backend.GetBackend("snmp_discovery") - - customLogLevel := "debug" - // Configure with custom log_level - err = be.Configure(logger, repo, map[string]any{ - "host": serverURL.Hostname(), - "port": serverURL.Port(), - "log_level": customLogLevel, - }, config.BackendCommons{}) - assert.NoError(t, err) - - mockCmd := &mocks.MockCmd{} - mocks.SetupSuccessfulProcess(mockCmd, 12345) - - originalNewCmdOptions := backend.NewCmdOptions - defer func() { - backend.NewCmdOptions = originalNewCmdOptions - }() - - // Verify that custom log level is used in command args - backend.NewCmdOptions = func(_ backend.CmdOptions, _ string, args ...string) backend.Commander { - assert.Contains(t, args, "--log-level") - // Find the index of --log-level and check the next argument - for i, arg := range args { - if arg == "--log-level" && i+1 < len(args) { - assert.Equal(t, customLogLevel, args[i+1]) - break - } - } - return mockCmd - } - - ctx, cancel := context.WithCancel(context.Background()) - err = be.Start(ctx, cancel) - assert.NoError(t, err) - - // Stop the backend - _ = be.Stop(ctx) - mockCmd.AssertExpectations(t) - }) + originalPath := os.Getenv("PATH") + t.Setenv("PATH", tempDir+string(os.PathListSeparator)+originalPath) +} - // Test dry run mode includes log level - t.Run("DryRunWithLogLevel", func(t *testing.T) { - // Create a test server for dry run tests (even though dry run doesn't use HTTP, the backend still tries readiness checks) - dryRunServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - switch { - case r.URL.Path == "/api/v1/status": - response := StatusResponse{Version: "1.3.5", StartTime: "2023-10-01T12:00:00Z", UpTime: 123.456} - _ = json.NewEncoder(w).Encode(response) - case r.URL.Path == "/api/v1/capabilities": - capabilities := map[string]any{"capability": true} - _ = json.NewEncoder(w).Encode(capabilities) - case strings.HasPrefix(r.URL.Path, "/api/v1/policies"): - _ = json.NewEncoder(w).Encode(map[string]any{"status": "ok"}) - default: - w.WriteHeader(http.StatusNotFound) - } - })) - defer dryRunServer.Close() - - dryRunServerURL, err := url.Parse(dryRunServer.URL) - assert.NoError(t, err) - - logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - repo, err := policies.NewMemRepo() - assert.NoError(t, err) - - assert.True(t, snmpdiscovery.Register()) - be := backend.GetBackend("snmp_discovery") - - beCommons := config.BackendCommons{ - Diode: struct { - Target string `yaml:"target"` - ClientID string `yaml:"client_id"` - ClientSecret string `yaml:"client_secret"` - AgentName string `yaml:"agent_name"` - DryRun bool `yaml:"dry_run"` - DryRunOutputDir string `yaml:"dry_run_output_dir"` - }{ - DryRun: true, - DryRunOutputDir: "/tmp/dry-run-output", - }, - } +func overrideNewCmdOptions(t *testing.T, cmd backend.Commander, assertFn func(options backend.CmdOptions, name string, args []string)) { + t.Helper() - customLogLevel := "debug" - err = be.Configure(logger, repo, map[string]any{ - "host": dryRunServerURL.Hostname(), - "port": dryRunServerURL.Port(), - "log_level": customLogLevel, - }, beCommons) - assert.NoError(t, err) - - mockCmd := &mocks.MockCmd{} - mocks.SetupSuccessfulProcess(mockCmd, 12345) - - originalNewCmdOptions := backend.NewCmdOptions - defer func() { - backend.NewCmdOptions = originalNewCmdOptions - }() - - // Verify that log level IS included in dry run mode - backend.NewCmdOptions = func(_ backend.CmdOptions, _ string, args ...string) backend.Commander { - assert.Contains(t, args, "--log-level") - // Find the index of --log-level and check the next argument - for i, arg := range args { - if arg == "--log-level" && i+1 < len(args) { - assert.Equal(t, customLogLevel, args[i+1]) - break - } - } - // Verify dry run specific args are also present - assert.Contains(t, args, "--dry-run") - assert.Contains(t, args, "--dry-run-output-dir") - // Verify non-dry-run args are NOT present - assert.NotContains(t, args, "--host") - assert.NotContains(t, args, "--port") - assert.NotContains(t, args, "--diode-target") - return mockCmd + original := backend.NewCmdOptions + backend.NewCmdOptions = func(options backend.CmdOptions, name string, args ...string) backend.Commander { + if assertFn != nil { + assertFn(options, name, args) } + return cmd + } - ctx, cancel := context.WithCancel(context.Background()) - err = be.Start(ctx, cancel) - assert.NoError(t, err) - - // Stop the backend - err = be.Stop(context.WithValue(context.Background(), config.ContextKey("routine"), "test")) - assert.NoError(t, err) - mockCmd.AssertExpectations(t) + t.Cleanup(func() { + backend.NewCmdOptions = original }) } diff --git a/agent/backend/worker/worker.go b/agent/backend/worker/worker.go index 0769f5a..6b26ba2 100644 --- a/agent/backend/worker/worker.go +++ b/agent/backend/worker/worker.go @@ -78,7 +78,9 @@ func (d *workerBackend) Configure(logger *slog.Logger, repo policies.PolicyRepo, if d.apiHost, prs = config["host"].(string); !prs { d.apiHost = defaultAPIHost } - if d.apiPort, prs = config["port"].(string); !prs { + if port, prs := config["port"]; prs { + d.apiPort = fmt.Sprintf("%v", port) + } else { d.apiPort = defaultAPIPort } @@ -218,6 +220,13 @@ func (d *workerBackend) Start(ctx context.Context, cancelFunc context.CancelFunc var version string var readinessErr error for backoff := range readinessBackoff { + if status := d.proc.Status(); status.Complete { + err := d.proc.Stop() + if err != nil { + d.logger.Error("proc.Stop error", slog.Any("error", err)) + } + return errors.New("worker process ended unexpectedly, check log") + } version, readinessErr = d.Version() if readinessErr == nil { d.logger.Info("worker readiness ok, got version ", diff --git a/agent/backend/worker/worker_test.go b/agent/backend/worker/worker_test.go index ed01994..5000229 100644 --- a/agent/backend/worker/worker_test.go +++ b/agent/backend/worker/worker_test.go @@ -8,10 +8,13 @@ import ( "net/http/httptest" "net/url" "os" + "path" "strings" + "sync/atomic" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/netboxlabs/orb-agent/agent/backend" "github.com/netboxlabs/orb-agent/agent/backend/mocks" @@ -20,266 +23,263 @@ import ( "github.com/netboxlabs/orb-agent/agent/policies" ) -type StatusResponse struct { - StartTime string `json:"start_time"` - Version string `json:"version"` - UpTime float64 `json:"up_time"` -} - func TestWorkerBackendStart(t *testing.T) { - // Create server server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - if r.URL.Path == "/api/v1/status" { - response := StatusResponse{ - Version: "1.3.4", - StartTime: "2023-10-01T12:00:00Z", - UpTime: 123.456, + switch { + case r.URL.Path == "/api/v1/status": + response := map[string]any{ + "version": "1.3.4", + "start_time": "2023-10-01T12:00:00Z", + "up_time": 123.456, + "up_time_min": 123.456, + "up_time_secs": 123.456, } w.WriteHeader(http.StatusOK) - err := json.NewEncoder(w).Encode(response) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - } else if r.URL.Path == "/api/v1/capabilities" { - capabilities := map[string]any{ - "capability": true, - } + require.NoError(t, json.NewEncoder(w).Encode(response)) + case r.URL.Path == "/api/v1/capabilities": w.WriteHeader(http.StatusOK) - err := json.NewEncoder(w).Encode(capabilities) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - } else if strings.Contains(r.URL.Path, "/api/v1/policies") { - switch r.Method { - case http.MethodPost: - w.WriteHeader(http.StatusOK) - response := map[string]any{ - "status": "success", - "message": "Policy applied successfully", - } - err := json.NewEncoder(w).Encode(response) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - case http.MethodDelete: - w.WriteHeader(http.StatusOK) - response := map[string]any{ - "status": "success", - "message": "Policy removed successfully", - } - err := json.NewEncoder(w).Encode(response) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - } - } else { + require.NoError(t, json.NewEncoder(w).Encode(map[string]any{"capability": true})) + case strings.Contains(r.URL.Path, "/api/v1/policies"): + w.WriteHeader(http.StatusOK) + require.NoError(t, json.NewEncoder(w).Encode(map[string]any{ + "status": "success", + "message": "Policy operation successful", + })) + default: w.WriteHeader(http.StatusNotFound) } })) defer server.Close() - // Parse server URL serverURL, err := url.Parse(server.URL) - assert.NoError(t, err) + require.NoError(t, err) + + createExecutable(t, "orb-worker") - // Create logger logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - // Create a mock repository repo, err := policies.NewMemRepo() - assert.NoError(t, err) + require.NoError(t, err) - // Create a mock command mockCmd := &mocks.MockCmd{} mocks.SetupSuccessfulProcess(mockCmd, 12345) - // Save original function and restore after test - originalNewCmdOptions := backend.NewCmdOptions - defer func() { - backend.NewCmdOptions = originalNewCmdOptions - }() - - // Override NewCmdOptions to return our mock - backend.NewCmdOptions = func(options backend.CmdOptions, name string, args ...string) backend.Commander { - // Assert that the correct parameters were passed + overrideNewCmdOptions(t, mockCmd, func(options backend.CmdOptions, name string, args []string) { assert.Equal(t, "orb-worker", name, "Expected command name to be orb-worker") - assert.Contains(t, args, "--port", "Expected args to contain port") - assert.Contains(t, args, "--host", "Expected args to contain host") assert.False(t, options.Buffered, "Expected buffered to be false") assert.True(t, options.Streaming, "Expected streaming to be true") - return mockCmd - } + assert.Contains(t, args, "--host", "Expected args to contain host flag") + assert.Contains(t, args, serverURL.Hostname(), "Expected args to contain host value") + assert.Contains(t, args, "--port", "Expected args to contain port flag") + assert.Contains(t, args, serverURL.Port(), "Expected args to contain port value") + assert.Contains(t, args, "--diode-target", "Expected args to contain diode target flag") + assert.Contains(t, args, "worker-target", "Expected args to contain diode target value") + assert.Contains(t, args, "--diode-client-id", "Expected args to contain diode client id flag") + assert.Contains(t, args, "worker-client", "Expected args to contain diode client id value") + assert.Contains(t, args, "--diode-client-secret", "Expected args to contain diode client secret flag") + assert.Contains(t, args, "worker-secret", "Expected args to contain diode client secret value") + assert.Contains(t, args, "--diode-app-name-prefix", "Expected args to contain diode app name prefix flag") + assert.Contains(t, args, "worker-agent", "Expected args to contain diode app name prefix value") + assert.Contains(t, args, "--otel-endpoint", "Expected args to contain otel endpoint flag") + assert.Contains(t, args, "collector:4317", "Expected args to contain otel endpoint value") + }) assert.True(t, worker.Register(), "Failed to register Worker backend") - assert.True(t, backend.HaveBackend("worker"), "Failed to get Worker backend") be := backend.GetBackend("worker") - // Configure backend - err = be.Configure(logger, repo, map[string]any{ - "host": serverURL.Hostname(), - "port": serverURL.Port(), - }, config.BackendCommons{}) - assert.NoError(t, err) - assert.Equal(t, backend.Unknown, be.GetInitialState()) - // Start the backend + commons := config.BackendCommons{} + commons.Otlp.Grpc = "collector:4317" + commons.Diode.Target = "default-target" + commons.Diode.ClientID = "default-client" + commons.Diode.ClientSecret = "default-secret" + commons.Diode.AgentName = "default-agent" + commons.Diode.DryRunOutputDir = "/tmp/default" + + err = be.Configure(logger, repo, map[string]any{ + "host": serverURL.Hostname(), + "port": serverURL.Port(), + "target": "worker-target", + "client_id": "worker-client", + "client_secret": "worker-secret", + "agent_name": "worker-agent", + "dry_run": false, + "dry_run_output_dir": "/tmp/worker", + }, commons) + require.NoError(t, err) + ctx, cancel := context.WithCancel(context.Background()) - err = be.Start(ctx, cancel) + defer cancel() + + require.NoError(t, be.Start(ctx, cancel)) - // Assert successful start - assert.NoError(t, err) + startTime := be.GetStartTime() + assert.False(t, startTime.IsZero(), "Expected start time to be set") - // Get Running status status, _, err := be.GetRunningStatus() - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, backend.Running, status, "Expected backend to be running") - // Get capabilities capabilities, err := be.GetCapabilities() - assert.NoError(t, err) - assert.Equal(t, capabilities["capability"], true, "Expected capability to be true") + require.NoError(t, err) + assert.Equal(t, true, capabilities["capability"], "Expected capability to be true") + + version, err := be.Version() + require.NoError(t, err) + assert.Equal(t, "1.3.4", version, "Expected version to match response") data := policies.PolicyData{ ID: "dummy-policy-id", Name: "dummy-policy-name", Data: map[string]any{"key": "value"}, } - // Apply policy - err = be.ApplyPolicy(data, false) - assert.NoError(t, err) + require.NoError(t, be.ApplyPolicy(data, false)) - // Update policy - err = be.ApplyPolicy(data, true) - assert.NoError(t, err) + updatedData := policies.PolicyData{ + ID: data.ID, + Name: "dummy-policy-updated", + Data: map[string]any{"key": "value"}, + PreviousPolicyData: &policies.PolicyData{ + Name: data.Name, + }, + } + require.NoError(t, be.ApplyPolicy(updatedData, true)) - // Assert restart - err = be.FullReset(ctx) - assert.NoError(t, err) + require.NoError(t, be.FullReset(ctx)) - // Verify expectations mockCmd.AssertExpectations(t) } -func TestWorkerBackendCompleted(t *testing.T) { - // Create a mock command that simulates a failure - mockCmd := &mocks.MockCmd{} - mocks.SetupCompletedProcess(mockCmd, 0, nil) - // Save original function and restore after test - originalNewCmdOptions := backend.NewCmdOptions - defer func() { - backend.NewCmdOptions = originalNewCmdOptions - }() - - // Override NewCmdOptions to return our mock - backend.NewCmdOptions = func(_ backend.CmdOptions, _ string, _ ...string) backend.Commander { - return mockCmd - } - - assert.True(t, worker.Register(), "Failed to register Worker backend") - - assert.True(t, backend.HaveBackend("worker"), "Failed to get Worker backend") - - be := backend.GetBackend("worker") - - // Configure backend with invalid parameters - err := be.Configure(slog.Default(), nil, map[string]any{ - "host": "invalid-host", - }, config.BackendCommons{}) - assert.NoError(t, err) +func TestWorkerGetRunningStatusAPIFailure(t *testing.T) { + var statusCalls atomic.Int32 - ctx, cancel := context.WithCancel(context.Background()) - err = be.Start(ctx, cancel) - - assert.Error(t, err) -} - -func TestWorkerBackendDryRun(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") + switch { case r.URL.Path == "/api/v1/status": - response := StatusResponse{Version: "1.3.4", StartTime: "2023-10-01T12:00:00Z", UpTime: 123.456} - _ = json.NewEncoder(w).Encode(response) - case r.URL.Path == "/api/v1/capabilities": - capabilities := map[string]any{"capability": true} - _ = json.NewEncoder(w).Encode(capabilities) - case strings.HasPrefix(r.URL.Path, "/api/v1/policies"): - _ = json.NewEncoder(w).Encode(map[string]any{"status": "ok"}) + if statusCalls.Add(1) == 1 { + response := map[string]any{ + "version": "1.3.4", + "start_time": "2023-10-01T12:00:00Z", + "up_time_min": 1.5, + } + w.WriteHeader(http.StatusOK) + require.NoError(t, json.NewEncoder(w).Encode(response)) + return + } + w.WriteHeader(http.StatusServiceUnavailable) + _, err := w.Write([]byte(`{"detail":"unavailable"}`)) + require.NoError(t, err) + case strings.Contains(r.URL.Path, "/api/v1/policies"): + w.WriteHeader(http.StatusOK) + require.NoError(t, json.NewEncoder(w).Encode(map[string]any{ + "status": "success", + "message": "Policy operation successful", + })) default: - w.WriteHeader(http.StatusNotFound) + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte(`{}`)) + require.NoError(t, err) } })) defer server.Close() serverURL, err := url.Parse(server.URL) - assert.NoError(t, err) + require.NoError(t, err) + + createExecutable(t, "orb-worker") logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + repo, err := policies.NewMemRepo() - assert.NoError(t, err) + require.NoError(t, err) mockCmd := &mocks.MockCmd{} - mocks.SetupSuccessfulProcess(mockCmd, 12345) + mocks.SetupSuccessfulProcess(mockCmd, 54321) - originalNewCmdOptions := backend.NewCmdOptions - defer func() { - backend.NewCmdOptions = originalNewCmdOptions - }() + overrideNewCmdOptions(t, mockCmd, nil) - backend.NewCmdOptions = func(options backend.CmdOptions, name string, args ...string) backend.Commander { - assert.Equal(t, "orb-worker", name, "Expected command name to be orb-worker") - assert.Contains(t, args, "--dry-run") - assert.Contains(t, args, "--dry-run-output-dir") - assert.NotContains(t, args, "--host") - assert.NotContains(t, args, "--port") - assert.False(t, options.Buffered, "Expected buffered to be false") - assert.True(t, options.Streaming, "Expected streaming to be true") - return mockCmd - } + assert.True(t, worker.Register(), "Failed to register Worker backend") - assert.True(t, worker.Register()) be := backend.GetBackend("worker") - beCommons := config.BackendCommons{ - Diode: struct { - Target string `yaml:"target"` - ClientID string `yaml:"client_id"` - ClientSecret string `yaml:"client_secret"` - AgentName string `yaml:"agent_name"` - DryRun bool `yaml:"dry_run"` - DryRunOutputDir string `yaml:"dry_run_output_dir"` - }{ - DryRun: true, - DryRunOutputDir: "/tmp/dry-run-output", - }, - } - err = be.Configure(logger, repo, map[string]any{ "host": serverURL.Hostname(), "port": serverURL.Port(), - }, beCommons) - assert.NoError(t, err) + }, config.BackendCommons{}) + require.NoError(t, err) ctx, cancel := context.WithCancel(context.Background()) - err = be.Start(ctx, cancel) - assert.NoError(t, err) + defer cancel() - assert.False(t, be.GetStartTime().IsZero()) + require.NoError(t, be.Start(ctx, cancel)) - err = be.RemovePolicy(policies.PolicyData{ID: "1", Name: "policy", Data: map[string]any{"k": "v"}}) - assert.NoError(t, err) + status, message, err := be.GetRunningStatus() + assert.Equal(t, backend.BackendError, status, "Expected backend to report API failure") + assert.Equal(t, "process running, REST API unavailable", message) + assert.Error(t, err) + require.NoError(t, be.Stop(ctx)) - err = be.Stop(context.WithValue(context.Background(), config.ContextKey("routine"), "test")) - assert.NoError(t, err) + mockCmd.AssertExpectations(t) +} + +func TestWorkerBackendCompleted(t *testing.T) { + mockCmd := &mocks.MockCmd{} + mocks.SetupCompletedProcess(mockCmd, 0, nil) + + overrideNewCmdOptions(t, mockCmd, nil) + + assert.True(t, worker.Register(), "Failed to register Worker backend") + assert.True(t, backend.HaveBackend("worker"), "Failed to get Worker backend") + + be := backend.GetBackend("worker") + + require.NoError(t, be.Configure(slog.Default(), nil, map[string]any{ + "host": "invalid-host", + }, config.BackendCommons{})) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + err := be.Start(ctx, cancel) + assert.Error(t, err) mockCmd.AssertExpectations(t) } + +func createExecutable(t *testing.T, name string) { + t.Helper() + + tempDir := t.TempDir() + binaryPath := path.Join(tempDir, name) + + file, err := os.Create(binaryPath) + require.NoError(t, err) + require.NoError(t, file.Close()) + require.NoError(t, os.Chmod(binaryPath, 0o755)) + + originalPath := os.Getenv("PATH") + t.Setenv("PATH", tempDir+string(os.PathListSeparator)+originalPath) +} + +func overrideNewCmdOptions(t *testing.T, cmd backend.Commander, assertFn func(options backend.CmdOptions, name string, args []string)) { + t.Helper() + + original := backend.NewCmdOptions + backend.NewCmdOptions = func(options backend.CmdOptions, name string, args ...string) backend.Commander { + if assertFn != nil { + assertFn(options, name, args) + } + return cmd + } + + t.Cleanup(func() { + backend.NewCmdOptions = original + }) +} diff --git a/go.mod b/go.mod index f4eebc1..6acca36 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.24.4 require ( github.com/eclipse/paho.golang v0.22.0 - github.com/eclipse/paho.mqtt.golang v1.5.0 github.com/go-cmd/cmd v1.4.3 github.com/go-co-op/gocron/v2 v2.15.0 github.com/go-git/go-git/v5 v5.14.0 @@ -99,7 +98,6 @@ require ( golang.org/x/crypto v0.39.0 // indirect golang.org/x/exp v0.0.0-20250531010427-b6e5de432a8b // indirect golang.org/x/net v0.41.0 // indirect - golang.org/x/sync v0.15.0 // indirect golang.org/x/sys v0.33.0 // indirect golang.org/x/text v0.26.0 // indirect golang.org/x/time v0.11.0 // indirect diff --git a/go.sum b/go.sum index 2c09202..21d370f 100644 --- a/go.sum +++ b/go.sum @@ -40,8 +40,6 @@ github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4 github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/eclipse/paho.golang v0.22.0 h1:JhhUngr8TBlyUZDZw/L6WVayPi9qmSmdWeki48i5AVE= github.com/eclipse/paho.golang v0.22.0/go.mod h1:9ZiYJ93iEfGRJri8tErNeStPKLXIGBHiqbHV74t5pqI= -github.com/eclipse/paho.mqtt.golang v1.5.0 h1:EH+bUVJNgttidWFkLLVKaQPGmkTUfQQqjOsyvMGvD6o= -github.com/eclipse/paho.mqtt.golang v1.5.0/go.mod h1:du/2qNQVqJf/Sqs4MEL77kR8QTqANF7XU7Fk0aOTAgk= github.com/elazarl/goproxy v1.7.2 h1:Y2o6urb7Eule09PjlhQRGNsqRfPmYI3KKQLFpCAV3+o= github.com/elazarl/goproxy v1.7.2/go.mod h1:82vkLNir0ALaW14Rc399OTTjyNREgmdL2cVoIbS6XaE= github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= @@ -258,8 +256,6 @@ golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8= -golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=