Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions agent/configmgr/fleet.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"log/slog"
"time"

"gopkg.in/yaml.v3"

"github.com/netboxlabs/orb-agent/agent/backend"
"github.com/netboxlabs/orb-agent/agent/config"
"github.com/netboxlabs/orb-agent/agent/configmgr/fleet"
Expand Down Expand Up @@ -72,8 +74,19 @@ func (fleetManager *fleetConfigManager) Start(cfg config.Config, backends map[st
"inbox_topic", topics.Inbox,
"outbox_topic", topics.Outbox)

// use the generated topics to connect over MQTT v5
err = fleetManager.connection.Connect(ctx, jwtClaims.MqttURL, token.AccessToken, jwtClaims.AgentID, *topics, backends, cfg.OrbAgent.ConfigManager.Sources.Fleet.ClientID, jwtClaims.Zone, cfg.OrbAgent.Labels)
connectionDetails := fleet.ConnectionDetails{
MQTTURL: jwtClaims.MqttURL,
Token: token.AccessToken,
AgentID: jwtClaims.AgentID,
Topics: *topics,
ClientID: cfg.OrbAgent.ConfigManager.Sources.Fleet.ClientID,
Zone: jwtClaims.Zone,
}
configYaml, err := yaml.Marshal(cfg)
if err != nil {
return fmt.Errorf("failed to marshal agent config: %w", err)
}
err = fleetManager.connection.Connect(ctx, connectionDetails, backends, cfg.OrbAgent.Labels, string(configYaml))
if err != nil {
return err
}
Expand All @@ -93,7 +106,7 @@ func (fleetManager *fleetConfigManager) Start(cfg config.Config, backends map[st

// Reconnect
connectCtx := context.Background()
err = fleetManager.connection.Connect(connectCtx, jwtClaims.MqttURL, token.AccessToken, jwtClaims.AgentID, *topics, backends, cfg.OrbAgent.ConfigManager.Sources.Fleet.ClientID, jwtClaims.Zone, cfg.OrbAgent.Labels)
err = fleetManager.connection.Connect(connectCtx, connectionDetails, backends, cfg.OrbAgent.Labels, string(configYaml))
if err != nil {
fleetManager.logger.Error("failed to reconnect during reset", "error", err)
}
Expand Down
44 changes: 27 additions & 17 deletions agent/configmgr/fleet/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,22 @@ type TopicActions struct {
Unsubscribe func(topic string) error
}

// ConnectionDetails contains the details needed to connect to the MQTT broker
type ConnectionDetails struct {
MQTTURL string
Token string
AgentID string
Topics TokenResponseTopics
ClientID string
Zone string
}

// Connect connects to the MQTT broker
func (connection *MQTTConnection) Connect(ctx context.Context, fleetMQTTURL, token, agentID string, topics TokenResponseTopics, backends map[string]backend.Backend, clientID, zone string, labels map[string]string) error {
func (connection *MQTTConnection) Connect(ctx context.Context, details ConnectionDetails, backends map[string]backend.Backend, labels map[string]string, configFile string) error {
// Parse the ORB URL
serverURL, err := url.Parse(fleetMQTTURL)
serverURL, err := url.Parse(details.MQTTURL)
if err != nil {
connection.logger.Error("failed to parse MQTT URL", "url", fleetMQTTURL, "error", err)
connection.logger.Error("failed to parse MQTT URL", "url", details.MQTTURL, "error", err)
return err
}

Expand All @@ -65,21 +75,21 @@ func (connection *MQTTConnection) Connect(ctx context.Context, fleetMQTTURL, tok

_, err := cm.Subscribe(context.Background(), &paho.Subscribe{
Subscriptions: []paho.SubscribeOptions{
{Topic: topics.Inbox, QoS: 1},
{Topic: details.Topics.Inbox, QoS: 1},
},
})
if err != nil {
connection.logger.Error("failed to subscribe", "topic", topics.Inbox, "error", err)
connection.logger.Error("failed to subscribe", "topic", details.Topics.Inbox, "error", err)
} else {
connection.logger.Info("successfully subscribed", "topic", topics.Inbox)
connection.logger.Info("successfully subscribed", "topic", details.Topics.Inbox)
}

// start heartbeat loop bound to the same connection-level context
go connection.heartbeater.sendHeartbeats(ctx, func() {}, topics.Heartbeat, clientID, connection.publishToTopic)
go connection.heartbeater.sendHeartbeats(ctx, func() {}, details.Topics.Heartbeat, details.ClientID, connection.publishToTopic)

connection.messaging.sendCapabilities(ctx, backends, labels, func(ctx context.Context, payload []byte) error {
connection.messaging.sendCapabilities(ctx, backends, labels, configFile, func(ctx context.Context, payload []byte) error {
_, err := cm.Publish(ctx, &paho.Publish{
Topic: topics.Capabilities,
Topic: details.Topics.Capabilities,
Payload: payload,
QoS: 1,
Retain: false,
Expand All @@ -91,7 +101,7 @@ func (connection *MQTTConnection) Connect(ctx context.Context, fleetMQTTURL, tok
}

connection.logger.Debug("capabilities sent",
"topic", topics.Capabilities,
"topic", details.Topics.Capabilities,
"payload", string(payload),
)
return nil
Expand All @@ -101,7 +111,7 @@ func (connection *MQTTConnection) Connect(ctx context.Context, fleetMQTTURL, tok
time.Sleep(10 * time.Second)
go connection.messaging.sendGroupMembershipsRequest(ctx, func(ctx context.Context, payload []byte) error {
_, err := cm.Publish(ctx, &paho.Publish{
Topic: topics.Outbox,
Topic: details.Topics.Outbox,
Payload: payload,
QoS: 1,
Retain: false,
Expand All @@ -117,7 +127,7 @@ func (connection *MQTTConnection) Connect(ctx context.Context, fleetMQTTURL, tok
connection.logger.Error("MQTT connection error", "error", err)
},
ClientConfig: paho.ClientConfig{
ClientID: clientID,
ClientID: details.ClientID,
OnPublishReceived: []func(paho.PublishReceived) (bool, error){
func(pr paho.PublishReceived) (bool, error) {
// Log any published messages to subscribed topics
Expand All @@ -131,7 +141,7 @@ func (connection *MQTTConnection) Connect(ctx context.Context, fleetMQTTURL, tok
context.Background(),
pr.Packet.Payload,
orgID,
agentID,
details.AgentID,
TopicActions{
Subscribe: connection.subscribeToTopic,
Publish: connection.publishToTopic,
Expand All @@ -149,10 +159,10 @@ func (connection *MQTTConnection) Connect(ctx context.Context, fleetMQTTURL, tok
}

// Set authentication if token is provided
if token != "" {
connection.logger.Info("setting MQTT authentication", "client_id", clientID, "zone", zone)
cfg.ConnectUsername = fmt.Sprintf("%s:%s", zone, clientID)
cfg.ConnectPassword = []byte(token)
if details.Token != "" {
connection.logger.Info("setting MQTT authentication", "client_id", details.ClientID, "zone", details.Zone)
cfg.ConnectUsername = fmt.Sprintf("%s:%s", details.Zone, details.ClientID)
cfg.ConnectPassword = []byte(details.Token)
}

// Create and start the connection manager using the long-lived context.
Expand Down
15 changes: 13 additions & 2 deletions agent/configmgr/fleet/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,13 @@ func TestFleetConfigManager_Connect_InvalidURL(t *testing.T) {
// Act with invalid URL
backends := make(map[string]backend.Backend)
trt := TokenResponseTopics{Inbox: "test/topic"}
err := connection.Connect(context.Background(), "://invalid-url", "test_token", "test-agent-id", trt, backends, "test-agent-id", "test-zone", map[string]string{})
err := connection.Connect(
context.Background(),
ConnectionDetails{MQTTURL: "://invalid-url", Token: "test_token", AgentID: "test-agent-id", Topics: trt, ClientID: "test-agent-id", Zone: "test-zone"},
backends,
map[string]string{},
"",
)

// Assert
assert.Error(t, err)
Expand All @@ -96,7 +102,12 @@ func TestFleetConfigManager_Connect_ValidURL(t *testing.T) {
// Timeout after 3 seconds
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
err := connection.Connect(ctx, "mqtt://localhost:1883", "test_token", "test-agent-id", trt2, backends, "test-agent-id", "test-zone", map[string]string{})
err := connection.Connect(ctx,
ConnectionDetails{MQTTURL: "mqtt://localhost:1883", Token: "test_token", AgentID: "test-agent-id", Topics: trt2, ClientID: "test-agent-id", Zone: "test-zone"},
backends,
map[string]string{},
"",
)

// Assert - we expect connection to fail since no server is running,
// but URL parsing should succeed
Expand Down
1 change: 1 addition & 0 deletions agent/configmgr/fleet/messages/fleet_messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ type Capabilities struct {
OrbAgent OrbAgentInfo `json:"orb_agent"`
AgentLabels map[string]string `json:"agent_labels"`
Backends map[string]BackendInfo `json:"backends"`
AgentConfig string `json:"agent_config"`
}

// GroupMemberships represents the group memberships of an agent
Expand Down
24 changes: 13 additions & 11 deletions agent/configmgr/fleet/to_rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,8 @@ func (messaging *Messaging) sendGroupMembershipsRequest(ctx context.Context, pub
messaging.logger.Info("group memberships request sent", "value", string(body))
}

func (messaging *Messaging) sendCapabilities(ctx context.Context, backends map[string]backend.Backend, labels map[string]string, publishFunc func(ctx context.Context, payload []byte) error) {
capabilities := messages.Capabilities{
SchemaVersion: messages.CurrentCapabilitiesSchemaVersion,
AgentLabels: labels,
OrbAgent: messages.OrbAgentInfo{
Version: version.GetBuildVersion(),
},
}

capabilities.Backends = make(map[string]messages.BackendInfo)
func (messaging *Messaging) sendCapabilities(ctx context.Context, backends map[string]backend.Backend, labels map[string]string, config string, publishFunc func(ctx context.Context, payload []byte) error) {
backendsInfo := make(map[string]messages.BackendInfo)
for name, be := range backends {
ver, err := be.Version()
if err != nil {
Expand All @@ -50,12 +42,22 @@ func (messaging *Messaging) sendCapabilities(ctx context.Context, backends map[s
messaging.logger.Error("backend failed to retrieve capabilities, skipping", "backend", name, "error", err)
continue
}
capabilities.Backends[name] = messages.BackendInfo{
backendsInfo[name] = messages.BackendInfo{
Version: ver,
Data: cp,
}
}

capabilities := messages.Capabilities{
SchemaVersion: messages.CurrentCapabilitiesSchemaVersion,
AgentLabels: labels,
OrbAgent: messages.OrbAgentInfo{
Version: version.GetBuildVersion(),
},
Backends: backendsInfo,
AgentConfig: config,
}

body, err := json.Marshal(capabilities)
if err != nil {
messaging.logger.Error("backend failed to marshal capabilities, skipping", "error", err)
Expand Down
Loading