diff --git a/agent/configmgr/fleet/connection.go b/agent/configmgr/fleet/connection.go index 1099309..cefc244 100644 --- a/agent/configmgr/fleet/connection.go +++ b/agent/configmgr/fleet/connection.go @@ -26,11 +26,12 @@ type MQTTConnection struct { // NewMQTTConnection creates a new MQTTConnection func NewMQTTConnection(logger *slog.Logger, pMgr policymgr.PolicyManager, resetChan chan struct{}, backendState backend.StateRetriever) *MQTTConnection { + groupManager := newGroupManager() return &MQTTConnection{ connectionManager: nil, logger: logger, - heartbeater: newHeartbeater(logger, backendState, pMgr), - messaging: NewMessaging(logger, pMgr, resetChan), + heartbeater: newHeartbeater(logger, backendState, pMgr, &groupManager), + messaging: NewMessaging(logger, pMgr, resetChan, &groupManager), resetChan: resetChan, } } @@ -96,6 +97,8 @@ func (connection *MQTTConnection) Connect(ctx context.Context, fleetMQTTURL, tok return nil }) + // Wait for capabilities to be handled + time.Sleep(10 * time.Second) go connection.messaging.sendGroupMembershipsRequest(ctx, func(ctx context.Context, payload []byte) error { _, err := cm.Publish(ctx, &paho.Publish{ Topic: topics.Outbox, diff --git a/agent/configmgr/fleet/from_rpc.go b/agent/configmgr/fleet/from_rpc.go index 01b5023..1fa5785 100644 --- a/agent/configmgr/fleet/from_rpc.go +++ b/agent/configmgr/fleet/from_rpc.go @@ -16,14 +16,16 @@ import ( type Messaging struct { logger *slog.Logger policyManager policymgr.PolicyManager + groupManager *GroupManager resetChan chan struct{} } // NewMessaging creates a new Messaging -func NewMessaging(logger *slog.Logger, policyManager policymgr.PolicyManager, resetChan chan struct{}) *Messaging { +func NewMessaging(logger *slog.Logger, policyManager policymgr.PolicyManager, resetChan chan struct{}, groupManager *GroupManager) *Messaging { return &Messaging{ logger: logger, policyManager: policyManager, + groupManager: groupManager, resetChan: resetChan, } } @@ -50,7 +52,7 @@ func (messaging *Messaging) DispatchToHandlers(ctx context.Context, payload []by messaging.logger.Error("failed to unmarshal payload", "error", err) return err } - messaging.handleGroupMemberships(ctx, groupMemberships.Payload, orgID, agentID, topicActions.Subscribe, topicActions.Publish) + messaging.handleGroupMemberships(ctx, groupMemberships.Payload, orgID, agentID, topicActions) case messages.AgentPolicyRPCFunc: agentPolicies := messages.AgentPolicyRPC{} if err := json.Unmarshal(payload, &agentPolicies); err != nil { @@ -93,25 +95,29 @@ func (messaging *Messaging) DispatchToHandlers(ctx context.Context, payload []by return nil } -func (messaging *Messaging) handleGroupMemberships(ctx context.Context, groupMemberships messages.GroupMembershipRPCPayload, orgID string, agentID string, subscribeFunc func(topic string) error, publishFunc func(ctx context.Context, topic string, payload []byte) error) { +func (messaging *Messaging) handleGroupMemberships(ctx context.Context, groupMemberships messages.GroupMembershipRPCPayload, orgID string, agentID string, topicActions TopicActions) { messaging.logger.Debug("handling group memberships", "payload", groupMemberships) - // if groupMemberships.FullList { - // // TODO: handle when this is the full list. We'll need to - // // - unsubscribe from all group topics not included in this request - // // - subscribe to all group topics - // } + if groupMemberships.FullList { + for _, group := range messaging.groupManager.GetAll() { + if err := topicActions.Unsubscribe(groupTopic(orgID, group.GroupID)); err != nil { + messaging.logger.Error("failed to unsubscribe from group topic", "group_id", group.GroupID, "error", err) + } + messaging.groupManager.Remove(group.GroupID) + } + } for _, group := range groupMemberships.Groups { + messaging.groupManager.Add(group) messaging.logger.Info("subscribing to group", "group", group) topic := groupTopic(orgID, group.GroupID) - err := subscribeFunc(topic) + err := topicActions.Subscribe(topic) if err != nil { messaging.logger.Error("failed to subscribe to group", "error", err) } else { messaging.logger.Info("subscribed to group topic for group ID", "group_id", group.GroupID) } } - err := messaging.sendAgentPoliciesRequest(ctx, orgID, agentID, publishFunc) + err := messaging.sendAgentPoliciesRequest(ctx, orgID, agentID, topicActions.Publish) if err != nil { messaging.logger.Error("failed to send agent policies request", "error", err) } diff --git a/agent/configmgr/fleet/from_rpc_test.go b/agent/configmgr/fleet/from_rpc_test.go index c5180fb..f508d44 100644 --- a/agent/configmgr/fleet/from_rpc_test.go +++ b/agent/configmgr/fleet/from_rpc_test.go @@ -260,7 +260,8 @@ func TestMessageHandlers_DispatchToHandlers(t *testing.T) { tt.setupMocks(mockPMgr) } resetChan := make(chan struct{}, 1) - handlers := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + handlers := NewMessaging(logger, mockPMgr, resetChan, &groupManager) agentID := "agent123" mockPublishToTopic := func(_ context.Context, _ string, _ []byte) error { @@ -300,7 +301,8 @@ func TestMessageHandlers_handleGroupMemberships_Success(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) mockPMgr := &mockPolicyManager{} resetChan := make(chan struct{}, 1) - handlers := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + handlers := NewMessaging(logger, mockPMgr, resetChan, &groupManager) // Mock subscribeToTopic function subscribedTopics := []string{} @@ -311,6 +313,9 @@ func TestMessageHandlers_handleGroupMemberships_Success(t *testing.T) { mockPublishToTopic := func(_ context.Context, _ string, _ []byte) error { return nil } + mockUnsubscribeFromTopic := func(_ string) error { + return nil + } agentID := "agent123" orgID := "org123" groupMemberships := messages.GroupMembershipRPCPayload{ @@ -323,7 +328,11 @@ func TestMessageHandlers_handleGroupMemberships_Success(t *testing.T) { // Act ctx := context.Background() - handlers.handleGroupMemberships(ctx, groupMemberships, orgID, agentID, mockSubscribeToTopic, mockPublishToTopic) + handlers.handleGroupMemberships(ctx, groupMemberships, orgID, agentID, TopicActions{ + Subscribe: mockSubscribeToTopic, + Publish: mockPublishToTopic, + Unsubscribe: mockUnsubscribeFromTopic, + }) // Assert expectedTopics := []string{"orgs/org123/groups/group1", "orgs/org123/groups/group2"} @@ -335,7 +344,8 @@ func TestMessageHandlers_handleGroupMemberships_InvalidPayload(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) mockPMgr := &mockPolicyManager{} resetChan := make(chan struct{}, 1) - handlers := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + handlers := NewMessaging(logger, mockPMgr, resetChan, &groupManager) // Mock subscribeToTopic function subscribedTopics := []string{} @@ -346,6 +356,9 @@ func TestMessageHandlers_handleGroupMemberships_InvalidPayload(t *testing.T) { mockPublishToTopic := func(_ context.Context, _ string, _ []byte) error { return nil } + mockUnsubscribeFromTopic := func(_ string) error { + return nil + } agentID := "agent123" orgID := "org123" // Create an invalid payload - empty groups is the closest we can get to testing invalid data @@ -354,10 +367,13 @@ func TestMessageHandlers_handleGroupMemberships_InvalidPayload(t *testing.T) { FullList: false, Groups: []messages.GroupMembershipData{}, } - // Act - This should not panic, just handle gracefully ctx := context.Background() - handlers.handleGroupMemberships(ctx, groupMemberships, orgID, agentID, mockSubscribeToTopic, mockPublishToTopic) + handlers.handleGroupMemberships(ctx, groupMemberships, orgID, agentID, TopicActions{ + Subscribe: mockSubscribeToTopic, + Publish: mockPublishToTopic, + Unsubscribe: mockUnsubscribeFromTopic, + }) // Assert - should not subscribe to any topics due to empty groups assert.Empty(t, subscribedTopics) @@ -368,7 +384,8 @@ func TestMessageHandlers_handleGroupMemberships_EmptyGroups(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) mockPMgr := &mockPolicyManager{} resetChan := make(chan struct{}, 1) - handlers := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + handlers := NewMessaging(logger, mockPMgr, resetChan, &groupManager) // Mock subscribeToTopic function subscribedTopics := []string{} @@ -385,10 +402,16 @@ func TestMessageHandlers_handleGroupMemberships_EmptyGroups(t *testing.T) { FullList: true, Groups: []messages.GroupMembershipData{}, // Empty groups } - + mockUnsubscribeFromTopic := func(_ string) error { + return nil + } // Act ctx := context.Background() - handlers.handleGroupMemberships(ctx, groupMemberships, orgID, agentID, mockSubscribeToTopic, mockPublishToTopic) + handlers.handleGroupMemberships(ctx, groupMemberships, orgID, agentID, TopicActions{ + Subscribe: mockSubscribeToTopic, + Publish: mockPublishToTopic, + Unsubscribe: mockUnsubscribeFromTopic, + }) // Assert - should not subscribe to any topics due to empty groups assert.Empty(t, subscribedTopics) @@ -399,7 +422,8 @@ func TestMessageHandlers_handleGroupMemberships_JSONMarshalError(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) mockPMgr := &mockPolicyManager{} resetChan := make(chan struct{}, 1) - handlers := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + handlers := NewMessaging(logger, mockPMgr, resetChan, &groupManager) // Mock subscribeToTopic function subscribedTopics := []string{} @@ -418,10 +442,16 @@ func TestMessageHandlers_handleGroupMemberships_JSONMarshalError(t *testing.T) { FullList: false, Groups: []messages.GroupMembershipData{}, } - + mockUnsubscribeFromTopic := func(_ string) error { + return nil + } // Act ctx := context.Background() - handlers.handleGroupMemberships(ctx, groupMemberships, orgID, agentID, mockSubscribeToTopic, mockPublishToTopic) + handlers.handleGroupMemberships(ctx, groupMemberships, orgID, agentID, TopicActions{ + Subscribe: mockSubscribeToTopic, + Publish: mockPublishToTopic, + Unsubscribe: mockUnsubscribeFromTopic, + }) // Assert - should not subscribe to any topics due to empty groups assert.Empty(t, subscribedTopics) @@ -473,7 +503,8 @@ func TestMessageHandlers_handleGroupMemberships_ComplexPayload(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) mockPMgr := &mockPolicyManager{} resetChan := make(chan struct{}, 1) - handlers := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + handlers := NewMessaging(logger, mockPMgr, resetChan, &groupManager) // Mock subscribeToTopic function subscribedTopics := []string{} @@ -484,6 +515,9 @@ func TestMessageHandlers_handleGroupMemberships_ComplexPayload(t *testing.T) { mockPublishToTopic := func(_ context.Context, _ string, _ []byte) error { return nil } + mockUnsubscribeFromTopic := func(_ string) error { + return nil + } agentID := "agent123" orgID := "org123" @@ -498,7 +532,11 @@ func TestMessageHandlers_handleGroupMemberships_ComplexPayload(t *testing.T) { // Act ctx := context.Background() - handlers.handleGroupMemberships(ctx, groupMemberships, orgID, agentID, mockSubscribeToTopic, mockPublishToTopic) + handlers.handleGroupMemberships(ctx, groupMemberships, orgID, agentID, TopicActions{ + Subscribe: mockSubscribeToTopic, + Publish: mockPublishToTopic, + Unsubscribe: mockUnsubscribeFromTopic, + }) // Assert expectedTopics := []string{"orgs/org123/groups/group1", "orgs/org123/groups/group2"} @@ -510,7 +548,8 @@ func TestMessageHandlers_handleGroupMemberships_SendsAgentPoliciesRequest(t *tes logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) mockPMgr := &mockPolicyManager{} resetChan := make(chan struct{}, 1) - handlers := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + handlers := NewMessaging(logger, mockPMgr, resetChan, &groupManager) // Mock subscribeToTopic function subscribedTopics := []string{} @@ -531,6 +570,9 @@ func TestMessageHandlers_handleGroupMemberships_SendsAgentPoliciesRequest(t *tes }{topic, payload}) return nil } + mockUnsubscribeFromTopic := func(_ string) error { + return nil + } agentID := "agent123" orgID := "org123" @@ -544,7 +586,11 @@ func TestMessageHandlers_handleGroupMemberships_SendsAgentPoliciesRequest(t *tes // Act ctx := context.Background() - handlers.handleGroupMemberships(ctx, groupMemberships, orgID, agentID, mockSubscribeToTopic, mockPublishToTopic) + handlers.handleGroupMemberships(ctx, groupMemberships, orgID, agentID, TopicActions{ + Subscribe: mockSubscribeToTopic, + Publish: mockPublishToTopic, + Unsubscribe: mockUnsubscribeFromTopic, + }) // Assert expectedTopics := []string{"orgs/org123/groups/group1", "orgs/org123/groups/group2"} @@ -579,7 +625,8 @@ func TestNewMessageHandlers(t *testing.T) { // Act resetChan := make(chan struct{}, 1) - handlers := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + handlers := NewMessaging(logger, mockPMgr, resetChan, &groupManager) // Assert assert.NotNil(t, handlers) @@ -593,7 +640,8 @@ func TestMessageHandlers_handleAgentPolicies_NotFullList(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) mockPMgr := &mockPolicyManager{} resetChan := make(chan struct{}, 1) - handlers := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + handlers := NewMessaging(logger, mockPMgr, resetChan, &groupManager) // Setup mock expectations mockPMgr.On("ManagePolicy", mock.MatchedBy(func(p config.PolicyPayload) bool { @@ -629,7 +677,8 @@ func TestMessageHandlers_handleAgentPolicies_FullList_RemovesOldPolicies(t *test mockPMgr := &mockPolicyManager{} mockRepo := &mockPolicyRepo{} resetChan := make(chan struct{}, 1) - handlers := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + handlers := NewMessaging(logger, mockPMgr, resetChan, &groupManager) // Setup existing policies in repo existingPolicies := []policies.PolicyData{ @@ -693,7 +742,8 @@ func TestMessageHandlers_handleAgentPolicies_SkipsSanitizeAction(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) mockPMgr := &mockPolicyManager{} resetChan := make(chan struct{}, 1) - handlers := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + handlers := NewMessaging(logger, mockPMgr, resetChan, &groupManager) // Create policy payload with sanitize action policies := []messages.AgentPolicyRPCPayload{ @@ -724,7 +774,8 @@ func TestMessageHandlers_handleAgentPolicies_FullList_GetAllFails(t *testing.T) mockPMgr := &mockPolicyManager{} mockRepo := &mockPolicyRepo{} resetChan := make(chan struct{}, 1) - handlers := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + handlers := NewMessaging(logger, mockPMgr, resetChan, &groupManager) // Setup mock expectations - GetAll fails mockPMgr.On("GetRepo").Return(mockRepo) @@ -760,7 +811,8 @@ func TestMessageHandlers_handleAgentGroupRemoval_RemovesPolicyWhenNoGroupsRemain mockPMgr := &mockPolicyManager{} mockRepo := &mockPolicyRepo{} resetChan := make(chan struct{}, 1) - handlers := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + handlers := NewMessaging(logger, mockPMgr, resetChan, &groupManager) unsubscribedTopics := []string{} mockUnsubscribeFromTopic := func(topic string) error { @@ -804,7 +856,8 @@ func TestMessageHandlers_handleAgentGroupRemoval_RemovesDatasetsWhenGroupsRemain mockPMgr := &mockPolicyManager{} mockRepo := &mockPolicyRepo{} resetChan := make(chan struct{}, 1) - handlers := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + handlers := NewMessaging(logger, mockPMgr, resetChan, &groupManager) unsubscribedTopics := []string{} mockUnsubscribeFromTopic := func(topic string) error { @@ -856,7 +909,8 @@ func TestMessageHandlers_handleAgentGroupRemoval_UnsubscribeFails(t *testing.T) logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) mockPMgr := &mockPolicyManager{} resetChan := make(chan struct{}, 1) - handlers := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + handlers := NewMessaging(logger, mockPMgr, resetChan, &groupManager) mockUnsubscribeFromTopic := func(_ string) error { return assert.AnError @@ -881,7 +935,8 @@ func TestMessageHandlers_DispatchToHandlers_InvalidJSON(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) mockPMgr := &mockPolicyManager{} resetChan := make(chan struct{}, 1) - handlers := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + handlers := NewMessaging(logger, mockPMgr, resetChan, &groupManager) invalidPayload := []byte("invalid json {") @@ -903,7 +958,8 @@ func TestMessageHandlers_DispatchToHandlers_MissingFunc(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) mockPMgr := &mockPolicyManager{} resetChan := make(chan struct{}, 1) - handlers := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + handlers := NewMessaging(logger, mockPMgr, resetChan, &groupManager) // Create RPC with empty Func rpc := messages.RPC{ @@ -932,7 +988,8 @@ func TestMessageHandlers_DispatchToHandlers_NilPayload(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) mockPMgr := &mockPolicyManager{} resetChan := make(chan struct{}, 1) - handlers := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + handlers := NewMessaging(logger, mockPMgr, resetChan, &groupManager) // Create RPC with nil Payload rpc := messages.RPC{ @@ -961,7 +1018,8 @@ func TestMessageHandlers_DispatchToHandlers_MalformedGroupMembershipPayload(t *t logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) mockPMgr := &mockPolicyManager{} resetChan := make(chan struct{}, 1) - handlers := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + handlers := NewMessaging(logger, mockPMgr, resetChan, &groupManager) // Create malformed payload - using string instead of proper structure malformedPayload := []byte(`{"schema_version":"1.0","func":"group_membership","payload":"not_a_valid_structure"}`) @@ -984,7 +1042,8 @@ func TestMessageHandlers_DispatchToHandlers_MalformedAgentPolicyPayload(t *testi logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) mockPMgr := &mockPolicyManager{} resetChan := make(chan struct{}, 1) - handlers := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + handlers := NewMessaging(logger, mockPMgr, resetChan, &groupManager) // Create malformed payload malformedPayload := []byte(`{"schema_version":"1.0","func":"agent_policy","payload":"not_an_array"}`) @@ -1007,7 +1066,8 @@ func TestMessageHandlers_DispatchToHandlers_MalformedGroupRemovedPayload(t *test logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) mockPMgr := &mockPolicyManager{} resetChan := make(chan struct{}, 1) - handlers := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + handlers := NewMessaging(logger, mockPMgr, resetChan, &groupManager) // Create malformed payload malformedPayload := []byte(`{"schema_version":"1.0","func":"group_removed","payload":"not_a_structure"}`) @@ -1030,7 +1090,8 @@ func TestMessageHandlers_DispatchToHandlers_MalformedDatasetRemovedPayload(t *te logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) mockPMgr := &mockPolicyManager{} resetChan := make(chan struct{}, 1) - handlers := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + handlers := NewMessaging(logger, mockPMgr, resetChan, &groupManager) // Create malformed payload malformedPayload := []byte(`{"schema_version":"1.0","func":"dataset_removed","payload":"not_a_structure"}`) @@ -1054,7 +1115,8 @@ func TestMessageHandlers_handleDatasetRemoval_Success(t *testing.T) { mockPMgr := &mockPolicyManager{} mockRepo := &mockPolicyRepo{} resetChan := make(chan struct{}, 1) - handlers := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + handlers := NewMessaging(logger, mockPMgr, resetChan, &groupManager) // Register a mock backend for testing mockBe := &mockBackend{} @@ -1095,7 +1157,8 @@ func TestMessageHandlers_handleDatasetRemoval_PolicyRetrievalFails(t *testing.T) mockPMgr := &mockPolicyManager{} mockRepo := &mockPolicyRepo{} resetChan := make(chan struct{}, 1) - handlers := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + handlers := NewMessaging(logger, mockPMgr, resetChan, &groupManager) // Setup mock expectations - Get fails mockPMgr.On("GetRepo").Return(mockRepo) @@ -1123,7 +1186,8 @@ func TestMessageHandlers_handleDatasetRemoval_BackendNotFound(t *testing.T) { mockPMgr := &mockPolicyManager{} mockRepo := &mockPolicyRepo{} resetChan := make(chan struct{}, 1) - handlers := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + handlers := NewMessaging(logger, mockPMgr, resetChan, &groupManager) // Setup mock expectations - policy exists but with nonexistent backend mockPMgr.On("GetRepo").Return(mockRepo) @@ -1154,7 +1218,8 @@ func TestMessageHandlers_handleAgentReset_NoFullReset(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) mockPMgr := &mockPolicyManager{} resetChan := make(chan struct{}, 1) - handlers := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + handlers := NewMessaging(logger, mockPMgr, resetChan, &groupManager) ctx := context.Background() @@ -1183,7 +1248,8 @@ func TestMessageHandlers_DispatchToHandlers_AgentReset(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) mockPMgr := &mockPolicyManager{} resetChan := make(chan struct{}, 1) - handlers := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + handlers := NewMessaging(logger, mockPMgr, resetChan, &groupManager) ctx := context.Background() @@ -1218,7 +1284,8 @@ func TestMessageHandlers_DispatchToHandlers_MalformedAgentResetPayload(t *testin logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) mockPMgr := &mockPolicyManager{} resetChan := make(chan struct{}, 1) - handlers := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + handlers := NewMessaging(logger, mockPMgr, resetChan, &groupManager) // Create malformed payload malformedPayload := []byte(`{"schema_version":"1.0","func":"agent_reset","payload":"not_a_structure"}`) @@ -1241,7 +1308,8 @@ func TestMessageHandlers_DispatchToHandlers_AgentStop(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) mockPMgr := &mockPolicyManager{} resetChan := make(chan struct{}, 1) - _ = NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + _ = NewMessaging(logger, mockPMgr, resetChan, &groupManager) // Create agent_stop RPC message rpc := messages.AgentStopRPC{ @@ -1278,7 +1346,8 @@ func TestMessageHandlers_DispatchToHandlers_MalformedAgentStopPayload(t *testing logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) mockPMgr := &mockPolicyManager{} resetChan := make(chan struct{}, 1) - handlers := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + handlers := NewMessaging(logger, mockPMgr, resetChan, &groupManager) // Create malformed payload malformedPayload := []byte(`{"schema_version":"1.0","func":"agent_stop","payload":"not_a_structure"}`) diff --git a/agent/configmgr/fleet/groups.go b/agent/configmgr/fleet/groups.go new file mode 100644 index 0000000..2a26d6f --- /dev/null +++ b/agent/configmgr/fleet/groups.go @@ -0,0 +1,44 @@ +package fleet + +import "github.com/netboxlabs/orb-agent/agent/configmgr/fleet/messages" + +// GroupRetriever provides read-only access to group memberships +type GroupRetriever interface { + GetAll() []messages.GroupMembershipData +} + +// GroupManager manages the agent's group memberships +type GroupManager struct { + groups []messages.GroupMembershipData +} + +func newGroupManager() GroupManager { + return GroupManager{ + groups: []messages.GroupMembershipData{}, + } +} + +// Add adds a group to the manager +func (gm *GroupManager) Add(group messages.GroupMembershipData) { + gm.groups = append(gm.groups, group) +} + +// RemoveAll removes all groups from the manager +func (gm *GroupManager) RemoveAll() { + gm.groups = []messages.GroupMembershipData{} +} + +// Remove removes a specific group by ID from the manager +func (gm *GroupManager) Remove(groupID string) { + for i, group := range gm.groups { + if group.GroupID == groupID { + gm.groups = append(gm.groups[:i], gm.groups[i+1:]...) + break + } + } +} + +// GetAll returns all groups managed by the manager +func (gm *GroupManager) GetAll() []messages.GroupMembershipData { + return gm.groups +} diff --git a/agent/configmgr/fleet/groups_test.go b/agent/configmgr/fleet/groups_test.go new file mode 100644 index 0000000..45f1d05 --- /dev/null +++ b/agent/configmgr/fleet/groups_test.go @@ -0,0 +1,194 @@ +package fleet + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netboxlabs/orb-agent/agent/configmgr/fleet/messages" +) + +func TestGroupManager_Add_SingleGroup(t *testing.T) { + // Arrange + gm := newGroupManager() + group := messages.GroupMembershipData{ + GroupID: "group-1", + Name: "Test Group 1", + } + + // Act + gm.Add(group) + + // Assert + groups := gm.GetAll() + require.Len(t, groups, 1) + assert.Equal(t, "group-1", groups[0].GroupID) + assert.Equal(t, "Test Group 1", groups[0].Name) +} + +func TestGroupManager_Add_Groups(t *testing.T) { + // Arrange + gm := newGroupManager() + group1 := messages.GroupMembershipData{ + GroupID: "group-1", + Name: "Test Group 1", + } + group2 := messages.GroupMembershipData{ + GroupID: "group-2", + Name: "Test Group 2", + } + group3 := messages.GroupMembershipData{ + GroupID: "group-3", + Name: "Test Group 3", + } + + // Act + gm.Add(group1) + gm.Add(group2) + gm.Add(group3) + + // Assert + groups := gm.GetAll() + require.Len(t, groups, 3) + assert.Equal(t, "group-1", groups[0].GroupID) + assert.Equal(t, "group-2", groups[1].GroupID) + assert.Equal(t, "group-3", groups[2].GroupID) +} + +func TestGroupManager_Add_DuplicateGroups(t *testing.T) { + // Arrange + gm := newGroupManager() + group1 := messages.GroupMembershipData{ + GroupID: "group-1", + Name: "Test Group 1", + } + group2 := messages.GroupMembershipData{ + GroupID: "group-1", + Name: "Test Group 1", + } + + // Act + gm.Add(group1) + gm.Add(group2) + + // Assert - Add allows duplicates (no deduplication) + groups := gm.GetAll() + require.Len(t, groups, 2) + assert.Equal(t, "group-1", groups[0].GroupID) + assert.Equal(t, "group-1", groups[1].GroupID) +} + +func TestGroupManager_GetAll_EmptyGroups(t *testing.T) { + // Arrange + gm := newGroupManager() + + // Act + groups := gm.GetAll() + + // Assert + assert.NotNil(t, groups) + assert.Empty(t, groups) +} + +func TestGroupManager_RemoveAll_Success(t *testing.T) { + // Arrange + gm := newGroupManager() + gm.Add(messages.GroupMembershipData{GroupID: "group-1", Name: "Test Group 1"}) + gm.Add(messages.GroupMembershipData{GroupID: "group-2", Name: "Test Group 2"}) + gm.Add(messages.GroupMembershipData{GroupID: "group-3", Name: "Test Group 3"}) + + // Verify groups were added + require.Len(t, gm.GetAll(), 3) + + // Act + gm.RemoveAll() + + // Assert + groups := gm.GetAll() + assert.NotNil(t, groups) + assert.Empty(t, groups) +} + +func TestGroupManager_Remove_SingleGroup(t *testing.T) { + // Arrange + gm := newGroupManager() + gm.Add(messages.GroupMembershipData{GroupID: "group-1", Name: "Test Group 1"}) + gm.Add(messages.GroupMembershipData{GroupID: "group-2", Name: "Test Group 2"}) + gm.Add(messages.GroupMembershipData{GroupID: "group-3", Name: "Test Group 3"}) + + // Act + gm.Remove("group-2") + + // Assert + groups := gm.GetAll() + require.Len(t, groups, 2) + assert.Equal(t, "group-1", groups[0].GroupID) + assert.Equal(t, "group-3", groups[1].GroupID) +} + +func TestGroupManager_Remove_FirstGroup(t *testing.T) { + // Arrange + gm := newGroupManager() + gm.Add(messages.GroupMembershipData{GroupID: "group-1", Name: "Test Group 1"}) + gm.Add(messages.GroupMembershipData{GroupID: "group-2", Name: "Test Group 2"}) + gm.Add(messages.GroupMembershipData{GroupID: "group-3", Name: "Test Group 3"}) + + // Act + gm.Remove("group-1") + + // Assert + groups := gm.GetAll() + require.Len(t, groups, 2) + assert.Equal(t, "group-2", groups[0].GroupID) + assert.Equal(t, "group-3", groups[1].GroupID) +} + +func TestGroupManager_Remove_LastGroup(t *testing.T) { + // Arrange + gm := newGroupManager() + gm.Add(messages.GroupMembershipData{GroupID: "group-1", Name: "Test Group 1"}) + gm.Add(messages.GroupMembershipData{GroupID: "group-2", Name: "Test Group 2"}) + gm.Add(messages.GroupMembershipData{GroupID: "group-3", Name: "Test Group 3"}) + + // Act + gm.Remove("group-3") + + // Assert + groups := gm.GetAll() + require.Len(t, groups, 2) + assert.Equal(t, "group-1", groups[0].GroupID) + assert.Equal(t, "group-2", groups[1].GroupID) +} + +func TestGroupManager_Remove_NonExistentGroup(t *testing.T) { + // Arrange + gm := newGroupManager() + gm.Add(messages.GroupMembershipData{GroupID: "group-1", Name: "Test Group 1"}) + gm.Add(messages.GroupMembershipData{GroupID: "group-2", Name: "Test Group 2"}) + + // Act - Remove non-existent group should not panic + gm.Remove("group-999") + + // Assert - No change in groups + groups := gm.GetAll() + require.Len(t, groups, 2) + assert.Equal(t, "group-1", groups[0].GroupID) + assert.Equal(t, "group-2", groups[1].GroupID) +} + +func TestGroupManager_GroupRetrieverInterface(t *testing.T) { + // Arrange + gm := newGroupManager() + gm.Add(messages.GroupMembershipData{GroupID: "group-1", Name: "Test Group 1"}) + gm.Add(messages.GroupMembershipData{GroupID: "group-2", Name: "Test Group 2"}) + + // Act - Use as GroupRetriever interface + var retriever GroupRetriever = &gm + + // Assert + groups := retriever.GetAll() + require.Len(t, groups, 2) + assert.Equal(t, "group-1", groups[0].GroupID) + assert.Equal(t, "group-2", groups[1].GroupID) +} diff --git a/agent/configmgr/fleet/heartbeats.go b/agent/configmgr/fleet/heartbeats.go index b3e4991..15343a0 100644 --- a/agent/configmgr/fleet/heartbeats.go +++ b/agent/configmgr/fleet/heartbeats.go @@ -16,20 +16,22 @@ const ( ) type heartbeater struct { - logger *slog.Logger - hbTicker *time.Ticker - heartbeatCtx context.Context - backendState backend.StateRetriever - policyManager policymgr.PolicyManager + logger *slog.Logger + hbTicker *time.Ticker + heartbeatCtx context.Context + backendState backend.StateRetriever + policyManager policymgr.PolicyManager + groupRetriever GroupRetriever } -func newHeartbeater(logger *slog.Logger, backendState backend.StateRetriever, policyManager policymgr.PolicyManager) *heartbeater { +func newHeartbeater(logger *slog.Logger, backendState backend.StateRetriever, policyManager policymgr.PolicyManager, groupRetriever GroupRetriever) *heartbeater { return &heartbeater{ - logger: logger, - hbTicker: time.NewTicker(heartbeatFreq), - heartbeatCtx: context.Background(), - backendState: backendState, - policyManager: policyManager, + logger: logger, + hbTicker: time.NewTicker(heartbeatFreq), + heartbeatCtx: context.Background(), + backendState: backendState, + policyManager: policyManager, + groupRetriever: groupRetriever, } } @@ -45,7 +47,7 @@ func (hb *heartbeater) sendSingleHeartbeat(ctx context.Context, heartbeatTopic s State: messages.State(messages.Online), BackendState: hb.getBackendState(), PolicyState: hb.getPolicyState(), - GroupState: make(map[string]messages.GroupStateInfo), + GroupState: hb.getGroupState(), } body, err := json.Marshal(hbData) @@ -97,6 +99,17 @@ func (hb *heartbeater) getPolicyState() map[string]messages.PolicyStateInfo { return ps } +func (hb *heartbeater) getGroupState() map[string]messages.GroupStateInfo { + gs := make(map[string]messages.GroupStateInfo) + for _, group := range hb.groupRetriever.GetAll() { + gs[group.GroupID] = messages.GroupStateInfo{ + GroupName: group.Name, + GroupID: group.GroupID, + } + } + return gs +} + // sendHeartbeats starts a goroutine that periodically issues heartbeats until the // supplied context is cancelled. The cancelFunc parameter is ignored by the // implementation but is accepted for backward-compatibility with unit tests diff --git a/agent/configmgr/fleet/heartbeats_test.go b/agent/configmgr/fleet/heartbeats_test.go index f8357b7..6ca2469 100644 --- a/agent/configmgr/fleet/heartbeats_test.go +++ b/agent/configmgr/fleet/heartbeats_test.go @@ -73,12 +73,14 @@ func createTestHeartbeater() *heartbeater { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) mockPMgr := &mockPolicyManagerForHeartbeat{} mockPMgr.On("GetPolicyState").Return([]policies.PolicyData{}, nil).Maybe() + groupManager := newGroupManager() return &heartbeater{ - logger: logger, - hbTicker: time.NewTicker(50 * time.Millisecond), // Short interval for testing - heartbeatCtx: context.Background(), - backendState: &mockBackendState{}, - policyManager: mockPMgr, + logger: logger, + hbTicker: time.NewTicker(50 * time.Millisecond), // Short interval for testing + heartbeatCtx: context.Background(), + backendState: &mockBackendState{}, + policyManager: mockPMgr, + groupRetriever: &groupManager, } } @@ -86,23 +88,27 @@ func createTestHeartbeaterWithBackendState(backendState *mockBackendState) *hear logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) mockPMgr := &mockPolicyManagerForHeartbeat{} mockPMgr.On("GetPolicyState").Return([]policies.PolicyData{}, nil).Maybe() + groupManager := newGroupManager() return &heartbeater{ - logger: logger, - hbTicker: time.NewTicker(50 * time.Millisecond), // Short interval for testing - heartbeatCtx: context.Background(), - backendState: backendState, - policyManager: mockPMgr, + logger: logger, + hbTicker: time.NewTicker(50 * time.Millisecond), // Short interval for testing + heartbeatCtx: context.Background(), + backendState: backendState, + policyManager: mockPMgr, + groupRetriever: &groupManager, } } func createTestHeartbeaterWithPolicyManager(backendState *mockBackendState, policyManager *mockPolicyManagerForHeartbeat) *heartbeater { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + groupManager := newGroupManager() return &heartbeater{ - logger: logger, - hbTicker: time.NewTicker(50 * time.Millisecond), // Short interval for testing - heartbeatCtx: context.Background(), - backendState: backendState, - policyManager: policyManager, + logger: logger, + hbTicker: time.NewTicker(50 * time.Millisecond), // Short interval for testing + heartbeatCtx: context.Background(), + backendState: backendState, + policyManager: policyManager, + groupRetriever: &groupManager, } } @@ -1009,9 +1015,10 @@ func TestNewHeartbeater_WithPolicyManager(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) backendState := &mockBackendState{} mockPMgr := &mockPolicyManagerForHeartbeat{} + groupManager := newGroupManager() // Act - hb := newHeartbeater(logger, backendState, mockPMgr) + hb := newHeartbeater(logger, backendState, mockPMgr, &groupManager) // Assert assert.NotNil(t, hb) @@ -1021,6 +1028,380 @@ func TestNewHeartbeater_WithPolicyManager(t *testing.T) { assert.NotNil(t, hb.backendState) assert.NotNil(t, hb.policyManager) assert.Equal(t, mockPMgr, hb.policyManager) + assert.NotNil(t, hb.groupRetriever) + + // Clean up ticker + hb.hbTicker.Stop() +} + +func createTestHeartbeaterWithGroupManager(groupManager *GroupManager) *heartbeater { + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + mockPMgr := &mockPolicyManagerForHeartbeat{} + mockPMgr.On("GetPolicyState").Return([]policies.PolicyData{}, nil).Maybe() + return &heartbeater{ + logger: logger, + hbTicker: time.NewTicker(50 * time.Millisecond), // Short interval for testing + heartbeatCtx: context.Background(), + backendState: &mockBackendState{}, + policyManager: mockPMgr, + groupRetriever: groupManager, + } +} + +func TestHeartbeater_SendSingleHeartbeat_WithEmptyGroupState(t *testing.T) { + // Arrange + gm := newGroupManager() + hb := createTestHeartbeaterWithGroupManager(&gm) + defer hb.hbTicker.Stop() + + var capturedPayload []byte + testTopic := "test/heartbeat" + publishFunc := func(_ context.Context, _ string, payload []byte) error { + capturedPayload = payload + return nil + } + + ctx := context.Background() + testTime := time.Now() + + // Act + hb.sendSingleHeartbeat(ctx, testTopic, publishFunc, "test-agent-id", testTime, messages.Online) + + // Assert + require.NotNil(t, capturedPayload) + + var heartbeat messages.Heartbeat + err := json.Unmarshal(capturedPayload, &heartbeat) + require.NoError(t, err) + + // Verify group state is empty but not nil + assert.NotNil(t, heartbeat.GroupState) + assert.Empty(t, heartbeat.GroupState) +} + +func TestHeartbeater_SendSingleHeartbeat_WithGroupState(t *testing.T) { + // Arrange + gm := newGroupManager() + gm.Add(messages.GroupMembershipData{ + GroupID: "group-1", + Name: "Test Group 1", + }) + gm.Add(messages.GroupMembershipData{ + GroupID: "group-2", + Name: "Test Group 2", + }) + hb := createTestHeartbeaterWithGroupManager(&gm) + defer hb.hbTicker.Stop() + + var capturedPayload []byte + testTopic := "test/heartbeat" + publishFunc := func(_ context.Context, _ string, payload []byte) error { + capturedPayload = payload + return nil + } + + ctx := context.Background() + testTime := time.Now() + + // Act + hb.sendSingleHeartbeat(ctx, testTopic, publishFunc, "test-agent-id", testTime, messages.Online) + + // Assert + require.NotNil(t, capturedPayload) + + var heartbeat messages.Heartbeat + err := json.Unmarshal(capturedPayload, &heartbeat) + require.NoError(t, err) + + // Verify group state is populated + assert.NotNil(t, heartbeat.GroupState) + assert.Len(t, heartbeat.GroupState, 2) + + // Check group-1 + group1, ok := heartbeat.GroupState["group-1"] + assert.True(t, ok) + assert.Equal(t, "Test Group 1", group1.GroupName) + assert.Equal(t, "group-1", group1.GroupID) + + // Check group-2 + group2, ok := heartbeat.GroupState["group-2"] + assert.True(t, ok) + assert.Equal(t, "Test Group 2", group2.GroupName) + assert.Equal(t, "group-2", group2.GroupID) +} + +func TestHeartbeater_SendSingleHeartbeat_WithCompleteState(t *testing.T) { + // Test heartbeat with backend, policy, and group states all populated + // Arrange + testTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC) + + // Setup backend state + backendState := &mockBackendState{ + backendState: map[string]*backend.State{ + "pktvisor": { + Status: backend.Running, + RestartCount: 1, + LastError: "", + }, + }, + } + + // Setup policy manager + mockPMgr := &mockPolicyManagerForHeartbeat{} + mockPMgr.On("GetPolicyState").Return([]policies.PolicyData{ + { + ID: "policy-1", + Name: "Test Policy", + Backend: "pktvisor", + Version: 1, + State: policies.Running, + Datasets: map[string]bool{"dataset-1": true}, + }, + }, nil) + + // Setup group manager + gm := newGroupManager() + gm.Add(messages.GroupMembershipData{ + GroupID: "group-1", + Name: "Test Group 1", + }) + + // Create heartbeater with all components + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + hb := &heartbeater{ + logger: logger, + hbTicker: time.NewTicker(50 * time.Millisecond), + heartbeatCtx: context.Background(), + backendState: backendState, + policyManager: mockPMgr, + groupRetriever: &gm, + } + defer hb.hbTicker.Stop() + + var capturedPayload []byte + testTopic := "test/heartbeat" + publishFunc := func(_ context.Context, _ string, payload []byte) error { + capturedPayload = payload + return nil + } + + ctx := context.Background() + + // Act + hb.sendSingleHeartbeat(ctx, testTopic, publishFunc, "test-agent-id", testTime, messages.Online) + + // Assert + require.NotNil(t, capturedPayload) + + var heartbeat messages.Heartbeat + err := json.Unmarshal(capturedPayload, &heartbeat) + require.NoError(t, err) + + // Verify all state sections are present + assert.NotNil(t, heartbeat.BackendState) + assert.Len(t, heartbeat.BackendState, 1) + assert.NotNil(t, heartbeat.PolicyState) + assert.Len(t, heartbeat.PolicyState, 1) + assert.NotNil(t, heartbeat.GroupState) + assert.Len(t, heartbeat.GroupState, 1) + + // Verify backend state + backend, ok := heartbeat.BackendState["pktvisor"] + assert.True(t, ok) + assert.Equal(t, "running", backend.State) + + // Verify policy state + policy, ok := heartbeat.PolicyState["policy-1"] + assert.True(t, ok) + assert.Equal(t, "Test Policy", policy.Name) + + // Verify group state + group, ok := heartbeat.GroupState["group-1"] + assert.True(t, ok) + assert.Equal(t, "Test Group 1", group.GroupName) + + mockPMgr.AssertExpectations(t) +} + +func TestHeartbeater_SendSingleHeartbeat_GroupStateAfterRemoval(t *testing.T) { + // Arrange + gm := newGroupManager() + gm.Add(messages.GroupMembershipData{ + GroupID: "group-1", + Name: "Test Group 1", + }) + gm.Add(messages.GroupMembershipData{ + GroupID: "group-2", + Name: "Test Group 2", + }) + hb := createTestHeartbeaterWithGroupManager(&gm) + defer hb.hbTicker.Stop() + + var capturedPayload []byte + testTopic := "test/heartbeat" + publishFunc := func(_ context.Context, _ string, payload []byte) error { + capturedPayload = payload + return nil + } + + ctx := context.Background() + testTime := time.Now() + + // Send initial heartbeat with 2 groups + hb.sendSingleHeartbeat(ctx, testTopic, publishFunc, "test-agent-id", testTime, messages.Online) + require.NotNil(t, capturedPayload) + + var heartbeat1 messages.Heartbeat + err := json.Unmarshal(capturedPayload, &heartbeat1) + require.NoError(t, err) + assert.Len(t, heartbeat1.GroupState, 2) + + // Remove one group + gm.Remove("group-1") + + // Send second heartbeat after removal + capturedPayload = nil + hb.sendSingleHeartbeat(ctx, testTopic, publishFunc, "test-agent-id", testTime, messages.Online) + require.NotNil(t, capturedPayload) + + var heartbeat2 messages.Heartbeat + err = json.Unmarshal(capturedPayload, &heartbeat2) + require.NoError(t, err) + + // Assert - Should only have 1 group now + assert.Len(t, heartbeat2.GroupState, 1) + _, ok := heartbeat2.GroupState["group-1"] + assert.False(t, ok) + group2, ok := heartbeat2.GroupState["group-2"] + assert.True(t, ok) + assert.Equal(t, "Test Group 2", group2.GroupName) +} + +func TestHeartbeater_SendSingleHeartbeat_GroupStateAfterRemoveAll(t *testing.T) { + // Arrange + gm := newGroupManager() + gm.Add(messages.GroupMembershipData{ + GroupID: "group-1", + Name: "Test Group 1", + }) + gm.Add(messages.GroupMembershipData{ + GroupID: "group-2", + Name: "Test Group 2", + }) + hb := createTestHeartbeaterWithGroupManager(&gm) + defer hb.hbTicker.Stop() + + var capturedPayload []byte + testTopic := "test/heartbeat" + publishFunc := func(_ context.Context, _ string, payload []byte) error { + capturedPayload = payload + return nil + } + + ctx := context.Background() + testTime := time.Now() + + // Send initial heartbeat with 2 groups + hb.sendSingleHeartbeat(ctx, testTopic, publishFunc, "test-agent-id", testTime, messages.Online) + require.NotNil(t, capturedPayload) + + var heartbeat1 messages.Heartbeat + err := json.Unmarshal(capturedPayload, &heartbeat1) + require.NoError(t, err) + assert.Len(t, heartbeat1.GroupState, 2) + + // Remove all groups + gm.RemoveAll() + + // Send second heartbeat after removal + capturedPayload = nil + hb.sendSingleHeartbeat(ctx, testTopic, publishFunc, "test-agent-id", testTime, messages.Online) + require.NotNil(t, capturedPayload) + + var heartbeat2 messages.Heartbeat + err = json.Unmarshal(capturedPayload, &heartbeat2) + require.NoError(t, err) + + // Assert - Should have no groups now + assert.NotNil(t, heartbeat2.GroupState) + assert.Empty(t, heartbeat2.GroupState) +} + +func TestHeartbeater_SendSingleHeartbeat_DynamicGroupUpdates(t *testing.T) { + // Test that group state reflects dynamic updates + // Arrange + gm := newGroupManager() + hb := createTestHeartbeaterWithGroupManager(&gm) + defer hb.hbTicker.Stop() + + var capturedPayload []byte + testTopic := "test/heartbeat" + publishFunc := func(_ context.Context, _ string, payload []byte) error { + capturedPayload = payload + return nil + } + + ctx := context.Background() + testTime := time.Now() + + // Heartbeat 1: No groups + hb.sendSingleHeartbeat(ctx, testTopic, publishFunc, "test-agent-id", testTime, messages.Online) + var hb1 messages.Heartbeat + require.NoError(t, json.Unmarshal(capturedPayload, &hb1)) + assert.Empty(t, hb1.GroupState) + + // Add group + gm.Add(messages.GroupMembershipData{GroupID: "group-1", Name: "Group 1"}) + + // Heartbeat 2: 1 group + hb.sendSingleHeartbeat(ctx, testTopic, publishFunc, "test-agent-id", testTime, messages.Online) + var hb2 messages.Heartbeat + require.NoError(t, json.Unmarshal(capturedPayload, &hb2)) + assert.Len(t, hb2.GroupState, 1) + + // Add another group + gm.Add(messages.GroupMembershipData{GroupID: "group-2", Name: "Group 2"}) + + // Heartbeat 3: 2 groups + hb.sendSingleHeartbeat(ctx, testTopic, publishFunc, "test-agent-id", testTime, messages.Online) + var hb3 messages.Heartbeat + require.NoError(t, json.Unmarshal(capturedPayload, &hb3)) + assert.Len(t, hb3.GroupState, 2) + + // Remove one + gm.Remove("group-1") + + // Heartbeat 4: 1 group + hb.sendSingleHeartbeat(ctx, testTopic, publishFunc, "test-agent-id", testTime, messages.Online) + var hb4 messages.Heartbeat + require.NoError(t, json.Unmarshal(capturedPayload, &hb4)) + assert.Len(t, hb4.GroupState, 1) + _, ok := hb4.GroupState["group-2"] + assert.True(t, ok) +} + +func TestNewHeartbeater_WithGroupManager(t *testing.T) { + // Arrange + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) + backendState := &mockBackendState{} + mockPMgr := &mockPolicyManagerForHeartbeat{} + gm := newGroupManager() + gm.Add(messages.GroupMembershipData{ + GroupID: "group-1", + Name: "Test Group 1", + }) + + // Act + hb := newHeartbeater(logger, backendState, mockPMgr, &gm) + + // Assert + assert.NotNil(t, hb) + assert.NotNil(t, hb.groupRetriever) + + // Verify we can retrieve groups through the interface + groups := hb.groupRetriever.GetAll() + require.Len(t, groups, 1) + assert.Equal(t, "group-1", groups[0].GroupID) // Clean up ticker hb.hbTicker.Stop() diff --git a/agent/configmgr/fleet/messages/fleet_messages.go b/agent/configmgr/fleet/messages/fleet_messages.go index 1eebbd4..52cee80 100644 --- a/agent/configmgr/fleet/messages/fleet_messages.go +++ b/agent/configmgr/fleet/messages/fleet_messages.go @@ -42,8 +42,8 @@ type PolicyStateInfo struct { // GroupStateInfo contains state information for a group type GroupStateInfo struct { - GroupName string `json:"name"` - GroupChannel string `json:"channel"` + GroupName string `json:"name"` + GroupID string `json:"id"` } // Heartbeat represents an agent heartbeat message @@ -136,9 +136,8 @@ type GroupMembershipRPC struct { // GroupMembershipData contains information about a single group membership type GroupMembershipData struct { - GroupID string `json:"group_id"` - Name string `json:"name"` - ChannelID string `json:"channel_id"` + GroupID string `json:"group_id"` + Name string `json:"name"` } // GroupMembershipRPCPayload is the payload for group membership RPC messages diff --git a/agent/configmgr/fleet/to_rpc_test.go b/agent/configmgr/fleet/to_rpc_test.go index c63d497..6df7d0b 100644 --- a/agent/configmgr/fleet/to_rpc_test.go +++ b/agent/configmgr/fleet/to_rpc_test.go @@ -61,7 +61,8 @@ func TestMessaging_SendCapabilities_Success(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) mockPMgr := &mockPolicyManagerForToRPC{} resetChan := make(chan struct{}, 1) - messaging := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + messaging := NewMessaging(logger, mockPMgr, resetChan, &groupManager) // Create mock backends mockBackend1 := &mockBackend{} @@ -134,7 +135,8 @@ func TestMessaging_SendCapabilities_BackendVersionError(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) mockPMgr := &mockPolicyManagerForToRPC{} resetChan := make(chan struct{}, 1) - messaging := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + messaging := NewMessaging(logger, mockPMgr, resetChan, &groupManager) // Create mock backends - one succeeds, one fails on version mockBackend1 := &mockBackend{} @@ -184,7 +186,8 @@ func TestMessaging_SendCapabilities_BackendCapabilitiesError(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) mockPMgr := &mockPolicyManagerForToRPC{} resetChan := make(chan struct{}, 1) - messaging := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + messaging := NewMessaging(logger, mockPMgr, resetChan, &groupManager) // Create mock backends - one succeeds, one fails on capabilities mockBackend1 := &mockBackend{} @@ -235,7 +238,8 @@ func TestMessaging_SendCapabilities_PublishError(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) mockPMgr := &mockPolicyManagerForToRPC{} resetChan := make(chan struct{}, 1) - messaging := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + messaging := NewMessaging(logger, mockPMgr, resetChan, &groupManager) mockBackend1 := &mockBackend{} mockBackend1.On("Version").Return("1.0.0", nil) @@ -268,7 +272,8 @@ func TestMessaging_SendCapabilities_EmptyBackends(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) mockPMgr := &mockPolicyManagerForToRPC{} resetChan := make(chan struct{}, 1) - messaging := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + messaging := NewMessaging(logger, mockPMgr, resetChan, &groupManager) backends := map[string]backend.Backend{} // Empty backends labels := map[string]string{} @@ -300,7 +305,8 @@ func TestMessaging_SendCapabilities_AllBackendsFail(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) mockPMgr := &mockPolicyManagerForToRPC{} resetChan := make(chan struct{}, 1) - messaging := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + messaging := NewMessaging(logger, mockPMgr, resetChan, &groupManager) // All backends fail mockBackend1 := &mockBackend{} @@ -347,7 +353,8 @@ func TestMessaging_SendCapabilities_CapabilitiesStructure(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) mockPMgr := &mockPolicyManagerForToRPC{} resetChan := make(chan struct{}, 1) - messaging := NewMessaging(logger, mockPMgr, resetChan) + groupManager := newGroupManager() + messaging := NewMessaging(logger, mockPMgr, resetChan, &groupManager) mockBackend1 := &mockBackend{} mockBackend1.On("Version").Return("test-version", nil)