diff --git a/agent/agent.go b/agent/agent.go index 54ccec8..8c5c53a 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -57,6 +57,9 @@ func New(logger *slog.Logger, c config.Config) (Agent, error) { return nil, err } + // Pass a background context to the config manager at construction time. The + // manager keeps its own copy and later derives child contexts from the + // runtime context supplied in Agent.Start. cm := configmgr.New(logger, pm, c.OrbAgent.ConfigManager.Active) return &orbAgent{ diff --git a/agent/config/types.go b/agent/config/types.go index 1753697..2a7778b 100644 --- a/agent/config/types.go +++ b/agent/config/types.go @@ -41,8 +41,6 @@ type FleetManager struct { SkipTLS bool `yaml:"skip_tls"` ClientID string `yaml:"client_id"` ClientSecret string `yaml:"client_secret"` - AgentID string `yaml:"agent_id"` - MQTTURL string `yaml:"mqtt_url,omitempty"` } // Sources represents the configuration for manager sources, including cloud, local and git. diff --git a/agent/configmgr/fleet.go b/agent/configmgr/fleet.go index ebdbc1e..7eb1248 100644 --- a/agent/configmgr/fleet.go +++ b/agent/configmgr/fleet.go @@ -27,9 +27,10 @@ import ( var _ Manager = (*fleetConfigManager)(nil) type fleetConfigManager struct { - logger *slog.Logger - pMgr policymgr.PolicyManager - heartbeater *heartbeater + logger *slog.Logger + pMgr policymgr.PolicyManager + heartbeater *heartbeater + connectionManager *autopaho.ConnectionManager } const ( @@ -144,21 +145,18 @@ func (fleetManager *fleetConfigManager) Start(cfg config.Config, backends map[st // use MQTT URL from token response or fallback to config mqttURL := jwtClaims.MqttURL if mqttURL == "" { - mqttURL = cfg.OrbAgent.ConfigManager.Sources.Fleet.MQTTURL - } - if mqttURL == "" { - return fmt.Errorf("no MQTT URL provided in token response or config") + return fmt.Errorf("no MQTT URL provided in token response") } // use the generated topics to connect over MQTT v5 - err = fleetManager.connect(ctx, mqttURL, token.AccessToken, *topics, backends, cfg.OrbAgent.ConfigManager.Sources.Fleet.ClientID, jwtClaims.Zone) + err = fleetManager.connect(ctx, mqttURL, token.AccessToken, *topics, backends, cfg.OrbAgent.ConfigManager.Sources.Fleet.ClientID, jwtClaims.Zone, cfg.OrbAgent.Labels) if err != nil { return err } return nil } -func (fleetManager *fleetConfigManager) connect(ctx context.Context, fleetMQTTURL, token string, topics tokenResponseTopics, backends map[string]backend.Backend, clientID, zone string) error { +func (fleetManager *fleetConfigManager) connect(ctx context.Context, fleetMQTTURL, token string, topics tokenResponseTopics, backends map[string]backend.Backend, clientID, zone string, labels map[string]string) error { // Parse the ORB URL serverURL, err := url.Parse(fleetMQTTURL) if err != nil { @@ -177,17 +175,16 @@ func (fleetManager *fleetConfigManager) connect(ctx context.Context, fleetMQTTUR OnConnectionUp: func(cm *autopaho.ConnectionManager, _ *paho.Connack) { fleetManager.logger.Info("MQTT connection established", "server", serverURL.String()) - // //Subscribe to "mytopic" when connection is established - // _, err := cm.Subscribe(context.Background(), &paho.Subscribe{ - // Subscriptions: []paho.SubscribeOptions{ - // {Topic: topics.Inbox, QoS: 1}, - // }, - // }) - // if err != nil { - // fleetManager.logger.Error("failed to subscribe", "topic", topics.Inbox, "error", err) - // } else { - // fleetManager.logger.Info("successfully subscribed", "topic", topics.Inbox) - // } + _, err := cm.Subscribe(context.Background(), &paho.Subscribe{ + Subscriptions: []paho.SubscribeOptions{ + {Topic: topics.Inbox, QoS: 1}, + }, + }) + if err != nil { + fleetManager.logger.Error("failed to subscribe", "topic", topics.Inbox, "error", err) + } else { + fleetManager.logger.Info("successfully subscribed", "topic", topics.Inbox) + } // start heartbeat loop bound to the same connection-level context go fleetManager.heartbeater.sendHeartbeats(ctx, func() {}, func(ctx context.Context, payload []byte) error { @@ -213,7 +210,7 @@ func (fleetManager *fleetConfigManager) connect(ctx context.Context, fleetMQTTUR return nil }, clientID) - go fleetManager.sendCapabilities(ctx, backends, func(ctx context.Context, payload []byte) error { + go fleetManager.sendCapabilities(ctx, backends, labels, func(ctx context.Context, payload []byte) error { _, err := cm.Publish(ctx, &paho.Publish{ Topic: topics.Capabilities, Payload: payload, @@ -222,6 +219,28 @@ func (fleetManager *fleetConfigManager) connect(ctx context.Context, fleetMQTTUR }) if err != nil { // TODO: reconnect? + fleetManager.logger.Error("failed to publish capabilities", "error", err) + return err + } + + fleetManager.logger.Debug("capabilities sent", + "topic", topics.Capabilities, + "payload", string(payload), + ) + return nil + }) + + // TODO: this is a hack to work around the race condition of capabilities not being processed by the time we request group memberships + time.Sleep(10 * time.Second) + go fleetManager.sendGroupMembershipsRequest(ctx, func(ctx context.Context, payload []byte) error { + _, err := cm.Publish(ctx, &paho.Publish{ + Topic: topics.Outbox, + Payload: payload, + QoS: 1, + Retain: false, + }) + if err != nil { + fleetManager.logger.Error("failed to publish group memberships request", "error", err) return err } return nil @@ -234,14 +253,17 @@ func (fleetManager *fleetConfigManager) connect(ctx context.Context, fleetMQTTUR ClientID: clientID, OnPublishReceived: []func(paho.PublishReceived) (bool, error){ func(pr paho.PublishReceived) (bool, error) { - messageType := pr.Packet.Properties.User.Get(messageTypeUserPropertyKey) // Log any published messages to subscribed topics - fleetManager.logger.Info("received MQTT message", - "topic", pr.Packet.Topic, - "payload", string(pr.Packet.Payload), - "message_type", messageType) + fleetManager.logger.Info("received MQTT message", "topic", pr.Packet.Topic) + + orgID := strings.Split(pr.Packet.Topic, "/")[1] + var rpc messages.RPC + if err := json.Unmarshal(pr.Packet.Payload, &rpc); err != nil { + fleetManager.logger.Error("failed to unmarshal RPC", "error", err) + return true, nil + } - fleetManager.dispatchToHandlers(messageType, pr.Packet.Payload) + fleetManager.dispatchToHandlers(rpc.Func, rpc, orgID) return true, nil }, @@ -257,7 +279,7 @@ func (fleetManager *fleetConfigManager) connect(ctx context.Context, fleetMQTTUR } // Create and start the connection manager using the long-lived context. - connectionManager, err := autopaho.NewConnection(ctx, cfg) + fleetManager.connectionManager, err = autopaho.NewConnection(ctx, cfg) if err != nil { fleetManager.logger.Error("failed to create MQTT connection", "error", err) return err @@ -268,7 +290,7 @@ func (fleetManager *fleetConfigManager) connect(ctx context.Context, fleetMQTTUR waitCtx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - err = connectionManager.AwaitConnection(waitCtx) + err = fleetManager.connectionManager.AwaitConnection(waitCtx) if err != nil { fleetManager.logger.Error("failed to establish initial MQTT connection", "error", err) return err @@ -278,10 +300,29 @@ func (fleetManager *fleetConfigManager) connect(ctx context.Context, fleetMQTTUR return nil } -func (fleetManager *fleetConfigManager) sendCapabilities(ctx context.Context, backends map[string]backend.Backend, publishFunc func(ctx context.Context, payload []byte) error) { +func (fleetManager *fleetConfigManager) sendGroupMembershipsRequest(ctx context.Context, publishFunc func(ctx context.Context, payload []byte) error) { + body, err := json.Marshal(messages.RPC{ + // SchemaVersion: messages.CurrentRPCSchemaVersion, // TODO: add schema version check later + Func: "group_membership_req", + Payload: messages.SendGroupMembershipsRequest{}, + }) + if err != nil { + fleetManager.logger.Error("backend failed to marshal capabilities, skipping", "error", err) + return + } + + fleetManager.logger.Info("sending group memberships request", "value", string(body)) + err = publishFunc(ctx, body) + if err != nil { + fleetManager.logger.Error("error sending group memberships request", "error", err) + } + fleetManager.logger.Info("group memberships request sent", "value", string(body)) +} + +func (fleetManager *fleetConfigManager) sendCapabilities(ctx context.Context, backends map[string]backend.Backend, labels map[string]string, publishFunc func(ctx context.Context, payload []byte) error) { capabilities := messages.Capabilities{ SchemaVersion: messages.CurrentCapabilitiesSchemaVersion, - // AgentTags: fleetManager.config.OrbAgent.Tags, // TODO: add tags + AgentLabels: labels, OrbAgent: messages.OrbAgentInfo{ Version: version.GetBuildVersion(), }, @@ -318,14 +359,40 @@ func (fleetManager *fleetConfigManager) sendCapabilities(ctx context.Context, ba } } -func (fleetManager *fleetConfigManager) dispatchToHandlers(_ string, _ []byte) { - // TODO: dispatch to handlers - // switch messageType { - // case "config": - // fleetManager.handleConfig(payload) - // case "policy": - // fleetManager.handlePolicy(payload) - // } +func (fleetManager *fleetConfigManager) dispatchToHandlers(messageType string, rpc messages.RPC, orgID string) { + switch messageType { + case "group_membership": + fleetManager.handleGroupMemberships(rpc, orgID) + default: + fleetManager.logger.Debug("unknown message type", "message_type", messageType) + } +} + +func (fleetManager *fleetConfigManager) handleGroupMemberships(rpc messages.RPC, orgID string) { + fleetManager.logger.Debug("handling group memberships", "payload", rpc.Payload) + payloadJSON, err := json.Marshal(rpc.Payload) + if err != nil { + fleetManager.logger.Error("failed to marshal payload", "error", err) + return + } + groupMeberships := messages.GroupMemberships{} + if err := json.Unmarshal(payloadJSON, &groupMeberships); err != nil { + fleetManager.logger.Error("failed to unmarshal payload", "error", err) + return + } + + for _, group := range groupMeberships.Groups { + fleetManager.logger.Info("subscribing to group", "group", group) + _, err := fleetManager.connectionManager.Subscribe(context.Background(), &paho.Subscribe{ + Subscriptions: []paho.SubscribeOptions{ + {Topic: groupTopic(orgID, group.GroupID), QoS: 1}, + }, + }) + if err != nil { + fleetManager.logger.Error("failed to subscribe to group", "error", err) + } + fleetManager.logger.Info("subscribed to group topic for group ID", "group_id", group.GroupID) + } } type tokenResponseTopics struct { diff --git a/agent/configmgr/fleet_test.go b/agent/configmgr/fleet_test.go index 3cf4415..f909e4d 100644 --- a/agent/configmgr/fleet_test.go +++ b/agent/configmgr/fleet_test.go @@ -548,7 +548,7 @@ func TestFleetConfigManager_Connect_InvalidURL(t *testing.T) { // Act with invalid URL backends := make(map[string]backend.Backend) trt := tokenResponseTopics{Inbox: "test/topic"} - err := fleetManager.connect(context.Background(), "://invalid-url", "test_token", trt, backends, "test-agent-id", "test-zone") + err := fleetManager.connect(context.Background(), "://invalid-url", "test_token", trt, backends, "test-agent-id", "test-zone", map[string]string{}) // Assert assert.Error(t, err) @@ -569,7 +569,7 @@ func TestFleetConfigManager_Connect_ValidURL(t *testing.T) { // Timeout after 3 seconds ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() - err := fleetManager.connect(ctx, "mqtt://localhost:1883", "test_token", trt2, backends, "test-agent-id", "test-zone") + err := fleetManager.connect(ctx, "mqtt://localhost:1883", "test_token", trt2, backends, "test-agent-id", "test-zone", map[string]string{}) // Assert - we expect connection to fail since no server is running, // but URL parsing should succeed @@ -666,12 +666,17 @@ func TestFleetConfigManager_DispatchToHandlers(t *testing.T) { defer fleetManager.heartbeater.hbTicker.Stop() // Act - currently this method is a TODO, so it should not panic - payload := []byte(`{"test": "data"}`) + payload := map[string]any{"test": "data"} + rpc := messages.RPC{ + // SchemaVersion: messages.CurrentRPCSchemaVersion, + Func: "config", + Payload: payload, + } // This should not panic since it's currently empty implementation - fleetManager.dispatchToHandlers("config", payload) - fleetManager.dispatchToHandlers("policy", payload) - fleetManager.dispatchToHandlers("unknown", payload) + fleetManager.dispatchToHandlers("config", rpc, "test-org") + fleetManager.dispatchToHandlers("policy", rpc, "test-org") + fleetManager.dispatchToHandlers("unknown", rpc, "test-org") // Assert - reaching this point means no panic occurred assert.True(t, true, "dispatchToHandlers should handle all message types without panic") @@ -838,9 +843,10 @@ func TestFleetConfigManager_SendCapabilities_Success(t *testing.T) { } ctx := context.Background() + labels := map[string]string{} // Act - fleetManager.sendCapabilities(ctx, backends, publishFunc) + fleetManager.sendCapabilities(ctx, backends, labels, publishFunc) // Assert require.NotNil(t, capturedPayload) @@ -894,6 +900,8 @@ func TestFleetConfigManager_SendCapabilities_BackendVersionError(t *testing.T) { "backend2": mockBackend2, } + labels := map[string]string{} + var capturedPayload []byte publishFunc := func(_ context.Context, payload []byte) error { capturedPayload = payload @@ -903,7 +911,7 @@ func TestFleetConfigManager_SendCapabilities_BackendVersionError(t *testing.T) { ctx := context.Background() // Act - fleetManager.sendCapabilities(ctx, backends, publishFunc) + fleetManager.sendCapabilities(ctx, backends, labels, publishFunc) // Assert require.NotNil(t, capturedPayload) @@ -943,6 +951,8 @@ func TestFleetConfigManager_SendCapabilities_BackendCapabilitiesError(t *testing "backend2": mockBackend2, } + labels := map[string]string{} + var capturedPayload []byte publishFunc := func(_ context.Context, payload []byte) error { capturedPayload = payload @@ -952,7 +962,7 @@ func TestFleetConfigManager_SendCapabilities_BackendCapabilitiesError(t *testing ctx := context.Background() // Act - fleetManager.sendCapabilities(ctx, backends, publishFunc) + fleetManager.sendCapabilities(ctx, backends, labels, publishFunc) // Assert require.NotNil(t, capturedPayload) @@ -985,6 +995,8 @@ func TestFleetConfigManager_SendCapabilities_PublishError(t *testing.T) { "backend1": mockBackend1, } + labels := map[string]string{} + publishError := errors.New("publish failed") publishFunc := func(_ context.Context, _ []byte) error { return publishError @@ -993,7 +1005,7 @@ func TestFleetConfigManager_SendCapabilities_PublishError(t *testing.T) { ctx := context.Background() // Act - fleetManager.sendCapabilities(ctx, backends, publishFunc) + fleetManager.sendCapabilities(ctx, backends, labels, publishFunc) // Assert assert.Equal(t, publishError, publishFunc(ctx, []byte{})) @@ -1009,7 +1021,7 @@ func TestFleetConfigManager_SendCapabilities_EmptyBackends(t *testing.T) { defer fleetManager.heartbeater.hbTicker.Stop() backends := map[string]backend.Backend{} // Empty backends - + labels := map[string]string{} var capturedPayload []byte publishFunc := func(_ context.Context, payload []byte) error { capturedPayload = payload @@ -1019,7 +1031,7 @@ func TestFleetConfigManager_SendCapabilities_EmptyBackends(t *testing.T) { ctx := context.Background() // Act - fleetManager.sendCapabilities(ctx, backends, publishFunc) + fleetManager.sendCapabilities(ctx, backends, labels, publishFunc) // Assert require.NotNil(t, capturedPayload) @@ -1044,6 +1056,8 @@ func TestFleetConfigManager_SendCapabilities_AllBackendsFail(t *testing.T) { mockBackend1 := &mockBackend{} mockBackend2 := &mockBackend{} + labels := map[string]string{} + mockBackend1.On("Version").Return("", errors.New("version error")) mockBackend2.On("Version").Return("1.0.0", nil) mockBackend2.On("GetCapabilities").Return(map[string]any(nil), errors.New("capabilities error")) @@ -1062,7 +1076,7 @@ func TestFleetConfigManager_SendCapabilities_AllBackendsFail(t *testing.T) { ctx := context.Background() // Act - fleetManager.sendCapabilities(ctx, backends, publishFunc) + fleetManager.sendCapabilities(ctx, backends, labels, publishFunc) // Assert require.NotNil(t, capturedPayload) @@ -1099,6 +1113,8 @@ func TestFleetConfigManager_SendCapabilities_CapabilitiesStructure(t *testing.T) "test_backend": mockBackend1, } + labels := map[string]string{} + var capturedPayload []byte publishFunc := func(_ context.Context, payload []byte) error { capturedPayload = payload @@ -1108,7 +1124,7 @@ func TestFleetConfigManager_SendCapabilities_CapabilitiesStructure(t *testing.T) ctx := context.Background() // Act - fleetManager.sendCapabilities(ctx, backends, publishFunc) + fleetManager.sendCapabilities(ctx, backends, labels, publishFunc) // Assert require.NotNil(t, capturedPayload) @@ -1180,8 +1196,6 @@ func TestFleetConfigManager_Start_WithJWTTopicGeneration(t *testing.T) { SkipTLS: true, ClientID: "test_client_id", ClientSecret: "test_client_secret", - AgentID: "test-agent-123", - MQTTURL: "mqtt://fallback.example.com:1883", }, }, }, diff --git a/agent/configmgr/jwt_claims.go b/agent/configmgr/jwt_claims.go index 5fcf6ed..e0a0619 100644 --- a/agent/configmgr/jwt_claims.go +++ b/agent/configmgr/jwt_claims.go @@ -2,7 +2,6 @@ package configmgr import ( "fmt" - "strings" "github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4/jwt" @@ -70,30 +69,3 @@ func parseJWTClaims(tokenString string) (*JWTClaims, error) { return jwtClaims, nil } - -// fillTopicTemplate replaces placeholders in a topic template with actual values -func fillTopicTemplate(template string, claims *JWTClaims) string { - result := template - result = strings.ReplaceAll(result, "{org_id}", claims.OrgID) - result = strings.ReplaceAll(result, "{agent_id}", claims.AgentID) - return result -} - -const ( - heartbeatTemplate = "orgs/{org_id}/agents/{agent_id}/heartbeats" - capabilitiesTemplate = "orgs/{org_id}/agents/{agent_id}/capabilities" - inboxTemplate = "orgs/{org_id}/agents/{agent_id}/inbox" - outboxTemplate = "orgs/{org_id}/agents/{agent_id}/outbox" -) - -// generateTopicsFromTemplate creates actual topic names from templates using JWT claims and config agent_id -func generateTopicsFromTemplate(jwtClaims *JWTClaims) (*tokenResponseTopics, error) { - topics := &tokenResponseTopics{ - Heartbeat: fillTopicTemplate(heartbeatTemplate, jwtClaims), - Capabilities: fillTopicTemplate(capabilitiesTemplate, jwtClaims), - Inbox: fillTopicTemplate(inboxTemplate, jwtClaims), - Outbox: fillTopicTemplate(outboxTemplate, jwtClaims), - } - - return topics, nil -} diff --git a/agent/configmgr/messages/fleet_messages.go b/agent/configmgr/messages/fleet_messages.go index 62be59f..532cbf5 100644 --- a/agent/configmgr/messages/fleet_messages.go +++ b/agent/configmgr/messages/fleet_messages.go @@ -84,6 +84,28 @@ const CurrentCapabilitiesSchemaVersion = "1.0" type Capabilities struct { SchemaVersion string `json:"schema_version"` OrbAgent OrbAgentInfo `json:"orb_agent"` - AgentTags map[string]string `json:"agent_tags"` + AgentLabels map[string]string `json:"agent_labels"` Backends map[string]BackendInfo `json:"backends"` } + +// GroupMemberships represents the group memberships of an agent +type GroupMemberships struct { + FullList bool `json:"full_list"` + Groups []GroupMembership `json:"groups"` +} + +// GroupMembership represents a group membership of an agent +type GroupMembership struct { + GroupID string `json:"group_id"` + Name string `json:"name"` +} + +// RPC represents a request to or from the fleet manager +type RPC struct { + SchemaVersion string `json:"schema_version"` + Func string `json:"func"` + Payload any `json:"payload"` +} + +// SendGroupMembershipsRequest represents a request to send group memberships to the fleet manager +type SendGroupMembershipsRequest struct{} diff --git a/agent/configmgr/topics.go b/agent/configmgr/topics.go new file mode 100644 index 0000000..9da7fe1 --- /dev/null +++ b/agent/configmgr/topics.go @@ -0,0 +1,40 @@ +package configmgr + +import ( + "strings" +) + +// fillTopicTemplate replaces placeholders in a topic template with actual values +func fillTopicTemplate(template string, claims *JWTClaims) string { + result := template + result = strings.ReplaceAll(result, "{org_id}", claims.OrgID) + result = strings.ReplaceAll(result, "{agent_id}", claims.AgentID) + return result +} + +const ( + heartbeatTemplate = "orgs/{org_id}/agents/{agent_id}/heartbeats" + capabilitiesTemplate = "orgs/{org_id}/agents/{agent_id}/capabilities" + inboxTemplate = "orgs/{org_id}/agents/{agent_id}/inbox" + outboxTemplate = "orgs/{org_id}/agents/{agent_id}/outbox" + + groupsTemplate = "orgs/{org_id}/groups/{group_id}" +) + +// generateTopicsFromTemplate creates actual topic names from templates using JWT claims and config agent_id +func generateTopicsFromTemplate(jwtClaims *JWTClaims) (*tokenResponseTopics, error) { + topics := &tokenResponseTopics{ + Heartbeat: fillTopicTemplate(heartbeatTemplate, jwtClaims), + Capabilities: fillTopicTemplate(capabilitiesTemplate, jwtClaims), + Inbox: fillTopicTemplate(inboxTemplate, jwtClaims), + Outbox: fillTopicTemplate(outboxTemplate, jwtClaims), + } + + return topics, nil +} + +func groupTopic(orgID, groupID string) string { + result := strings.ReplaceAll(groupsTemplate, "{org_id}", orgID) + result = strings.ReplaceAll(result, "{group_id}", groupID) + return result +}