diff --git a/agent/configmgr/fleet/from_rpc.go b/agent/configmgr/fleet/from_rpc.go index 794541f..59864aa 100644 --- a/agent/configmgr/fleet/from_rpc.go +++ b/agent/configmgr/fleet/from_rpc.go @@ -62,6 +62,14 @@ func (messaging *Messaging) DispatchToHandlers(ctx context.Context, payload []by return err } messaging.handleAgentGroupRemoval(groupRemoved.Payload, topicActions.Unsubscribe) + + case messages.DatasetRemovedRPCFunc: + var r messages.DatasetRemovedRPC + if err := json.Unmarshal(payload, &r); err != nil { + messaging.logger.Error("error decoding dataset removal message from core", "error", messages.ErrSchemaMalformed) + return err + } + messaging.handleDatasetRemoval(r.Payload) default: messaging.logger.Debug("unknown rpc function", "func", rpc.Func) } @@ -166,3 +174,17 @@ func (messaging *Messaging) handleAgentGroupRemoval(rpc messages.GroupRemovedRPC } } } + +func (messaging *Messaging) handleDatasetRemoval(rpc messages.DatasetRemovedRPCPayload) { + policy, err := messaging.policyManager.GetRepo().Get(rpc.PolicyID) + if err != nil { + messaging.logger.Error("failed to retrieve policy", "policy_id", rpc.PolicyID, "error", err) + return + } + if !backend.HaveBackend(policy.Backend) { + messaging.logger.Error("policy backend not found", "policy_id", rpc.PolicyID, "policy_backend", policy.Backend) + return + } + be := backend.GetBackend(policy.Backend) + messaging.policyManager.RemovePolicyDataset(rpc.PolicyID, rpc.DatasetID, be) +} diff --git a/agent/configmgr/fleet/from_rpc_test.go b/agent/configmgr/fleet/from_rpc_test.go index 76895e5..5c46783 100644 --- a/agent/configmgr/fleet/from_rpc_test.go +++ b/agent/configmgr/fleet/from_rpc_test.go @@ -200,6 +200,36 @@ func TestMessageHandlers_DispatchToHandlers(t *testing.T) { expectedError: false, expectedPolicyMgrCall: true, }, + { + name: "dataset_removed message type", + messageType: "dataset_removed", + rpc: messages.RPC{ + SchemaVersion: "1.0", + Func: "dataset_removed", + Payload: map[string]any{ + "dataset_id": "dataset1", + "policy_id": "policy1", + }, + }, + orgID: "org123", + expectedTopics: []string{}, + setupMocks: func(m *mockPolicyManager) { + // Register a test backend if not already registered + if !backend.HaveBackend("test_backend_dispatch") { + mockBe := &mockBackend{} + backend.Register("test_backend_dispatch", mockBe) + } + mockRepo := &mockPolicyRepo{} + mockRepo.On("Get", "policy1").Return(policies.PolicyData{ + ID: "policy1", + Backend: "test_backend_dispatch", + }, nil) + m.On("GetRepo").Return(mockRepo) + m.On("RemovePolicyDataset", "policy1", "dataset1", mock.Anything).Return() + }, + expectedError: false, + expectedPolicyMgrCall: true, + }, { name: "unknown message type", messageType: "unknown_type", @@ -972,3 +1002,123 @@ func TestMessageHandlers_DispatchToHandlers_MalformedGroupRemovedPayload(t *test // Assert assert.Error(t, err) } + +// Test DispatchToHandlers with malformed dataset_removed payload +func TestMessageHandlers_DispatchToHandlers_MalformedDatasetRemovedPayload(t *testing.T) { + // Arrange + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + mockPMgr := &mockPolicyManager{} + handlers := NewMessaging(logger, mockPMgr) + + // Create malformed payload + malformedPayload := []byte(`{"schema_version":"1.0","func":"dataset_removed","payload":"not_a_structure"}`) + + // Act + ctx := context.Background() + err := handlers.DispatchToHandlers(ctx, malformedPayload, "org123", "agent123", TopicActions{ + Subscribe: func(_ string) error { return nil }, + Publish: func(_ context.Context, _ string, _ []byte) error { return nil }, + Unsubscribe: func(_ string) error { return nil }, + }) + + // Assert + assert.Error(t, err) +} + +// Test handleDatasetRemoval with successful dataset removal +func TestMessageHandlers_handleDatasetRemoval_Success(t *testing.T) { + // Arrange + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + mockPMgr := &mockPolicyManager{} + mockRepo := &mockPolicyRepo{} + handlers := NewMessaging(logger, mockPMgr) + + // Register a mock backend for testing + mockBe := &mockBackend{} + backend.Register("test_backend", mockBe) + defer func() { + // Clean up - manually remove from registry + // Note: There's no Unregister function, but for isolated tests this is fine + // In a real scenario, we would want an Unregister function + }() + + // Setup mock expectations + mockPMgr.On("GetRepo").Return(mockRepo) + mockRepo.On("Get", "policy1").Return(policies.PolicyData{ + ID: "policy1", + Name: "Test Policy", + Backend: "test_backend", + }, nil) + mockPMgr.On("RemovePolicyDataset", "policy1", "dataset1", mockBe).Return() + + // Create dataset removal payload + datasetRemoval := messages.DatasetRemovedRPCPayload{ + DatasetID: "dataset1", + PolicyID: "policy1", + } + + // Act + handlers.handleDatasetRemoval(datasetRemoval) + + // Assert + mockPMgr.AssertExpectations(t) + mockRepo.AssertExpectations(t) +} + +// Test handleDatasetRemoval when policy retrieval fails +func TestMessageHandlers_handleDatasetRemoval_PolicyRetrievalFails(t *testing.T) { + // Arrange + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + mockPMgr := &mockPolicyManager{} + mockRepo := &mockPolicyRepo{} + handlers := NewMessaging(logger, mockPMgr) + + // Setup mock expectations - Get fails + mockPMgr.On("GetRepo").Return(mockRepo) + mockRepo.On("Get", "policy1").Return(policies.PolicyData{}, assert.AnError) + + // Create dataset removal payload + datasetRemoval := messages.DatasetRemovedRPCPayload{ + DatasetID: "dataset1", + PolicyID: "policy1", + } + + // Act + handlers.handleDatasetRemoval(datasetRemoval) + + // Assert - should return early without calling RemovePolicyDataset + mockPMgr.AssertNotCalled(t, "RemovePolicyDataset", mock.Anything, mock.Anything, mock.Anything) + mockPMgr.AssertExpectations(t) + mockRepo.AssertExpectations(t) +} + +// Test handleDatasetRemoval when backend not found +func TestMessageHandlers_handleDatasetRemoval_BackendNotFound(t *testing.T) { + // Arrange + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + mockPMgr := &mockPolicyManager{} + mockRepo := &mockPolicyRepo{} + handlers := NewMessaging(logger, mockPMgr) + + // Setup mock expectations - policy exists but with nonexistent backend + mockPMgr.On("GetRepo").Return(mockRepo) + mockRepo.On("Get", "policy1").Return(policies.PolicyData{ + ID: "policy1", + Name: "Test Policy", + Backend: "nonexistent_backend", + }, nil) + + // Create dataset removal payload + datasetRemoval := messages.DatasetRemovedRPCPayload{ + DatasetID: "dataset1", + PolicyID: "policy1", + } + + // Act + handlers.handleDatasetRemoval(datasetRemoval) + + // Assert - should return early without calling RemovePolicyDataset + mockPMgr.AssertNotCalled(t, "RemovePolicyDataset", mock.Anything, mock.Anything, mock.Anything) + mockPMgr.AssertExpectations(t) + mockRepo.AssertExpectations(t) +}