diff --git a/agent/configmgr/fleet.go b/agent/configmgr/fleet.go index 3c18f89..10125a6 100644 --- a/agent/configmgr/fleet.go +++ b/agent/configmgr/fleet.go @@ -6,6 +6,8 @@ import ( "log/slog" "time" + "gopkg.in/yaml.v3" + "github.com/netboxlabs/orb-agent/agent/backend" "github.com/netboxlabs/orb-agent/agent/config" "github.com/netboxlabs/orb-agent/agent/configmgr/fleet" @@ -72,8 +74,19 @@ func (fleetManager *fleetConfigManager) Start(cfg config.Config, backends map[st "inbox_topic", topics.Inbox, "outbox_topic", topics.Outbox) - // use the generated topics to connect over MQTT v5 - err = fleetManager.connection.Connect(ctx, jwtClaims.MqttURL, token.AccessToken, jwtClaims.AgentID, *topics, backends, cfg.OrbAgent.ConfigManager.Sources.Fleet.ClientID, jwtClaims.Zone, cfg.OrbAgent.Labels) + connectionDetails := fleet.ConnectionDetails{ + MQTTURL: jwtClaims.MqttURL, + Token: token.AccessToken, + AgentID: jwtClaims.AgentID, + Topics: *topics, + ClientID: cfg.OrbAgent.ConfigManager.Sources.Fleet.ClientID, + Zone: jwtClaims.Zone, + } + configYaml, err := yaml.Marshal(cfg) + if err != nil { + return fmt.Errorf("failed to marshal agent config: %w", err) + } + err = fleetManager.connection.Connect(ctx, connectionDetails, backends, cfg.OrbAgent.Labels, string(configYaml)) if err != nil { return err } @@ -93,7 +106,7 @@ func (fleetManager *fleetConfigManager) Start(cfg config.Config, backends map[st // Reconnect connectCtx := context.Background() - err = fleetManager.connection.Connect(connectCtx, jwtClaims.MqttURL, token.AccessToken, jwtClaims.AgentID, *topics, backends, cfg.OrbAgent.ConfigManager.Sources.Fleet.ClientID, jwtClaims.Zone, cfg.OrbAgent.Labels) + err = fleetManager.connection.Connect(connectCtx, connectionDetails, backends, cfg.OrbAgent.Labels, string(configYaml)) if err != nil { fleetManager.logger.Error("failed to reconnect during reset", "error", err) } diff --git a/agent/configmgr/fleet/connection.go b/agent/configmgr/fleet/connection.go index cefc244..63d33d4 100644 --- a/agent/configmgr/fleet/connection.go +++ b/agent/configmgr/fleet/connection.go @@ -43,12 +43,22 @@ type TopicActions struct { Unsubscribe func(topic string) error } +// ConnectionDetails contains the details needed to connect to the MQTT broker +type ConnectionDetails struct { + MQTTURL string + Token string + AgentID string + Topics TokenResponseTopics + ClientID string + Zone string +} + // Connect connects to the MQTT broker -func (connection *MQTTConnection) Connect(ctx context.Context, fleetMQTTURL, token, agentID string, topics TokenResponseTopics, backends map[string]backend.Backend, clientID, zone string, labels map[string]string) error { +func (connection *MQTTConnection) Connect(ctx context.Context, details ConnectionDetails, backends map[string]backend.Backend, labels map[string]string, configFile string) error { // Parse the ORB URL - serverURL, err := url.Parse(fleetMQTTURL) + serverURL, err := url.Parse(details.MQTTURL) if err != nil { - connection.logger.Error("failed to parse MQTT URL", "url", fleetMQTTURL, "error", err) + connection.logger.Error("failed to parse MQTT URL", "url", details.MQTTURL, "error", err) return err } @@ -65,21 +75,21 @@ func (connection *MQTTConnection) Connect(ctx context.Context, fleetMQTTURL, tok _, err := cm.Subscribe(context.Background(), &paho.Subscribe{ Subscriptions: []paho.SubscribeOptions{ - {Topic: topics.Inbox, QoS: 1}, + {Topic: details.Topics.Inbox, QoS: 1}, }, }) if err != nil { - connection.logger.Error("failed to subscribe", "topic", topics.Inbox, "error", err) + connection.logger.Error("failed to subscribe", "topic", details.Topics.Inbox, "error", err) } else { - connection.logger.Info("successfully subscribed", "topic", topics.Inbox) + connection.logger.Info("successfully subscribed", "topic", details.Topics.Inbox) } // start heartbeat loop bound to the same connection-level context - go connection.heartbeater.sendHeartbeats(ctx, func() {}, topics.Heartbeat, clientID, connection.publishToTopic) + go connection.heartbeater.sendHeartbeats(ctx, func() {}, details.Topics.Heartbeat, details.ClientID, connection.publishToTopic) - connection.messaging.sendCapabilities(ctx, backends, labels, func(ctx context.Context, payload []byte) error { + connection.messaging.sendCapabilities(ctx, backends, labels, configFile, func(ctx context.Context, payload []byte) error { _, err := cm.Publish(ctx, &paho.Publish{ - Topic: topics.Capabilities, + Topic: details.Topics.Capabilities, Payload: payload, QoS: 1, Retain: false, @@ -91,7 +101,7 @@ func (connection *MQTTConnection) Connect(ctx context.Context, fleetMQTTURL, tok } connection.logger.Debug("capabilities sent", - "topic", topics.Capabilities, + "topic", details.Topics.Capabilities, "payload", string(payload), ) return nil @@ -101,7 +111,7 @@ func (connection *MQTTConnection) Connect(ctx context.Context, fleetMQTTURL, tok time.Sleep(10 * time.Second) go connection.messaging.sendGroupMembershipsRequest(ctx, func(ctx context.Context, payload []byte) error { _, err := cm.Publish(ctx, &paho.Publish{ - Topic: topics.Outbox, + Topic: details.Topics.Outbox, Payload: payload, QoS: 1, Retain: false, @@ -117,7 +127,7 @@ func (connection *MQTTConnection) Connect(ctx context.Context, fleetMQTTURL, tok connection.logger.Error("MQTT connection error", "error", err) }, ClientConfig: paho.ClientConfig{ - ClientID: clientID, + ClientID: details.ClientID, OnPublishReceived: []func(paho.PublishReceived) (bool, error){ func(pr paho.PublishReceived) (bool, error) { // Log any published messages to subscribed topics @@ -131,7 +141,7 @@ func (connection *MQTTConnection) Connect(ctx context.Context, fleetMQTTURL, tok context.Background(), pr.Packet.Payload, orgID, - agentID, + details.AgentID, TopicActions{ Subscribe: connection.subscribeToTopic, Publish: connection.publishToTopic, @@ -149,10 +159,10 @@ func (connection *MQTTConnection) Connect(ctx context.Context, fleetMQTTURL, tok } // Set authentication if token is provided - if token != "" { - connection.logger.Info("setting MQTT authentication", "client_id", clientID, "zone", zone) - cfg.ConnectUsername = fmt.Sprintf("%s:%s", zone, clientID) - cfg.ConnectPassword = []byte(token) + if details.Token != "" { + connection.logger.Info("setting MQTT authentication", "client_id", details.ClientID, "zone", details.Zone) + cfg.ConnectUsername = fmt.Sprintf("%s:%s", details.Zone, details.ClientID) + cfg.ConnectPassword = []byte(details.Token) } // Create and start the connection manager using the long-lived context. diff --git a/agent/configmgr/fleet/connection_test.go b/agent/configmgr/fleet/connection_test.go index 97e881e..00d676c 100644 --- a/agent/configmgr/fleet/connection_test.go +++ b/agent/configmgr/fleet/connection_test.go @@ -75,7 +75,13 @@ func TestFleetConfigManager_Connect_InvalidURL(t *testing.T) { // Act with invalid URL backends := make(map[string]backend.Backend) trt := TokenResponseTopics{Inbox: "test/topic"} - err := connection.Connect(context.Background(), "://invalid-url", "test_token", "test-agent-id", trt, backends, "test-agent-id", "test-zone", map[string]string{}) + err := connection.Connect( + context.Background(), + ConnectionDetails{MQTTURL: "://invalid-url", Token: "test_token", AgentID: "test-agent-id", Topics: trt, ClientID: "test-agent-id", Zone: "test-zone"}, + backends, + map[string]string{}, + "", + ) // Assert assert.Error(t, err) @@ -96,7 +102,12 @@ func TestFleetConfigManager_Connect_ValidURL(t *testing.T) { // Timeout after 3 seconds ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() - err := connection.Connect(ctx, "mqtt://localhost:1883", "test_token", "test-agent-id", trt2, backends, "test-agent-id", "test-zone", map[string]string{}) + err := connection.Connect(ctx, + ConnectionDetails{MQTTURL: "mqtt://localhost:1883", Token: "test_token", AgentID: "test-agent-id", Topics: trt2, ClientID: "test-agent-id", Zone: "test-zone"}, + backends, + map[string]string{}, + "", + ) // Assert - we expect connection to fail since no server is running, // but URL parsing should succeed diff --git a/agent/configmgr/fleet/messages/fleet_messages.go b/agent/configmgr/fleet/messages/fleet_messages.go index 52cee80..feb5e30 100644 --- a/agent/configmgr/fleet/messages/fleet_messages.go +++ b/agent/configmgr/fleet/messages/fleet_messages.go @@ -94,6 +94,7 @@ type Capabilities struct { OrbAgent OrbAgentInfo `json:"orb_agent"` AgentLabels map[string]string `json:"agent_labels"` Backends map[string]BackendInfo `json:"backends"` + AgentConfig string `json:"agent_config"` } // GroupMemberships represents the group memberships of an agent diff --git a/agent/configmgr/fleet/to_rpc.go b/agent/configmgr/fleet/to_rpc.go index 5d8243f..17489e7 100644 --- a/agent/configmgr/fleet/to_rpc.go +++ b/agent/configmgr/fleet/to_rpc.go @@ -29,16 +29,8 @@ func (messaging *Messaging) sendGroupMembershipsRequest(ctx context.Context, pub messaging.logger.Info("group memberships request sent", "value", string(body)) } -func (messaging *Messaging) sendCapabilities(ctx context.Context, backends map[string]backend.Backend, labels map[string]string, publishFunc func(ctx context.Context, payload []byte) error) { - capabilities := messages.Capabilities{ - SchemaVersion: messages.CurrentCapabilitiesSchemaVersion, - AgentLabels: labels, - OrbAgent: messages.OrbAgentInfo{ - Version: version.GetBuildVersion(), - }, - } - - capabilities.Backends = make(map[string]messages.BackendInfo) +func (messaging *Messaging) sendCapabilities(ctx context.Context, backends map[string]backend.Backend, labels map[string]string, config string, publishFunc func(ctx context.Context, payload []byte) error) { + backendsInfo := make(map[string]messages.BackendInfo) for name, be := range backends { ver, err := be.Version() if err != nil { @@ -50,12 +42,22 @@ func (messaging *Messaging) sendCapabilities(ctx context.Context, backends map[s messaging.logger.Error("backend failed to retrieve capabilities, skipping", "backend", name, "error", err) continue } - capabilities.Backends[name] = messages.BackendInfo{ + backendsInfo[name] = messages.BackendInfo{ Version: ver, Data: cp, } } + capabilities := messages.Capabilities{ + SchemaVersion: messages.CurrentCapabilitiesSchemaVersion, + AgentLabels: labels, + OrbAgent: messages.OrbAgentInfo{ + Version: version.GetBuildVersion(), + }, + Backends: backendsInfo, + AgentConfig: config, + } + body, err := json.Marshal(capabilities) if err != nil { messaging.logger.Error("backend failed to marshal capabilities, skipping", "error", err) diff --git a/agent/configmgr/fleet/to_rpc_test.go b/agent/configmgr/fleet/to_rpc_test.go index 6df7d0b..d3c5a0f 100644 --- a/agent/configmgr/fleet/to_rpc_test.go +++ b/agent/configmgr/fleet/to_rpc_test.go @@ -95,9 +95,21 @@ func TestMessaging_SendCapabilities_Success(t *testing.T) { ctx := context.Background() labels := map[string]string{} + config := `orb: + config_manager: + active: local + backends: + common: + diode: + target: grpc://192.168.0.100:8080/diode + client_id: ${DIODE_CLIENT_ID} + client_secret: ${DIODE_CLIENT_SECRET} + agent_name: agent01 + snmp_discovery: + ` // Act - messaging.sendCapabilities(ctx, backends, labels, publishFunc) + messaging.sendCapabilities(ctx, backends, labels, config, publishFunc) // Assert require.NotNil(t, capturedPayload) @@ -125,6 +137,8 @@ func TestMessaging_SendCapabilities_Success(t *testing.T) { assert.Equal(t, "mqtt", backend2Info.Data["protocol"]) assert.Equal(t, "tls", backend2Info.Data["encryption"]) + assert.Equal(t, config, capabilities.AgentConfig) + // Verify all mock expectations were met mockBackend1.AssertExpectations(t) mockBackend2.AssertExpectations(t) @@ -154,6 +168,19 @@ func TestMessaging_SendCapabilities_BackendVersionError(t *testing.T) { labels := map[string]string{} + config := `orb: + config_manager: + active: local + backends: + common: + diode: + target: grpc://192.168.0.100:8080/diode + client_id: ${DIODE_CLIENT_ID} + client_secret: ${DIODE_CLIENT_SECRET} + agent_name: agent01 + snmp_discovery: + ` + var capturedPayload []byte publishFunc := func(_ context.Context, payload []byte) error { capturedPayload = payload @@ -163,7 +190,7 @@ func TestMessaging_SendCapabilities_BackendVersionError(t *testing.T) { ctx := context.Background() // Act - messaging.sendCapabilities(ctx, backends, labels, publishFunc) + messaging.sendCapabilities(ctx, backends, labels, config, publishFunc) // Assert require.NotNil(t, capturedPayload) @@ -206,6 +233,18 @@ func TestMessaging_SendCapabilities_BackendCapabilitiesError(t *testing.T) { labels := map[string]string{} + config := `orb: + config_manager: + active: local + backends: + common: + diode: + target: grpc://192.168.0.100:8080/diode + client_id: ${DIODE_CLIENT_ID} + client_secret: ${DIODE_CLIENT_SECRET} + agent_name: agent01 + snmp_discovery: + ` var capturedPayload []byte publishFunc := func(_ context.Context, payload []byte) error { capturedPayload = payload @@ -215,7 +254,7 @@ func TestMessaging_SendCapabilities_BackendCapabilitiesError(t *testing.T) { ctx := context.Background() // Act - messaging.sendCapabilities(ctx, backends, labels, publishFunc) + messaging.sendCapabilities(ctx, backends, labels, config, publishFunc) // Assert require.NotNil(t, capturedPayload) @@ -251,6 +290,18 @@ func TestMessaging_SendCapabilities_PublishError(t *testing.T) { labels := map[string]string{} + config := `orb: + config_manager: + active: local + backends: + common: + diode: + target: grpc://192.168.0.100:8080/diode + client_id: ${DIODE_CLIENT_ID} + client_secret: ${DIODE_CLIENT_SECRET} + agent_name: agent01 + snmp_discovery: + ` publishError := errors.New("publish failed") publishFunc := func(_ context.Context, _ []byte) error { return publishError @@ -259,7 +310,7 @@ func TestMessaging_SendCapabilities_PublishError(t *testing.T) { ctx := context.Background() // Act - messaging.sendCapabilities(ctx, backends, labels, publishFunc) + messaging.sendCapabilities(ctx, backends, labels, config, publishFunc) // Assert assert.Equal(t, publishError, publishFunc(ctx, []byte{})) @@ -277,6 +328,19 @@ func TestMessaging_SendCapabilities_EmptyBackends(t *testing.T) { backends := map[string]backend.Backend{} // Empty backends labels := map[string]string{} + + config := `orb: + config_manager: + active: local + backends: + common: + diode: + target: grpc://192.168.0.100:8080/diode + client_id: ${DIODE_CLIENT_ID} + client_secret: ${DIODE_CLIENT_SECRET} + agent_name: agent01 + snmp_discovery: + ` var capturedPayload []byte publishFunc := func(_ context.Context, payload []byte) error { capturedPayload = payload @@ -286,7 +350,7 @@ func TestMessaging_SendCapabilities_EmptyBackends(t *testing.T) { ctx := context.Background() // Act - messaging.sendCapabilities(ctx, backends, labels, publishFunc) + messaging.sendCapabilities(ctx, backends, labels, config, publishFunc) // Assert require.NotNil(t, capturedPayload) @@ -323,6 +387,7 @@ func TestMessaging_SendCapabilities_AllBackendsFail(t *testing.T) { "backend2": mockBackend2, } + config := "" var capturedPayload []byte publishFunc := func(_ context.Context, payload []byte) error { capturedPayload = payload @@ -332,7 +397,7 @@ func TestMessaging_SendCapabilities_AllBackendsFail(t *testing.T) { ctx := context.Background() // Act - messaging.sendCapabilities(ctx, backends, labels, publishFunc) + messaging.sendCapabilities(ctx, backends, labels, config, publishFunc) // Assert require.NotNil(t, capturedPayload) @@ -372,6 +437,7 @@ func TestMessaging_SendCapabilities_CapabilitiesStructure(t *testing.T) { labels := map[string]string{} + config := "" var capturedPayload []byte publishFunc := func(_ context.Context, payload []byte) error { capturedPayload = payload @@ -381,7 +447,7 @@ func TestMessaging_SendCapabilities_CapabilitiesStructure(t *testing.T) { ctx := context.Background() // Act - messaging.sendCapabilities(ctx, backends, labels, publishFunc) + messaging.sendCapabilities(ctx, backends, labels, config, publishFunc) // Assert require.NotNil(t, capturedPayload)