Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions agent/configmgr/fleet/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ func NewMQTTConnection(logger *slog.Logger, pMgr policymgr.PolicyManager) *MQTTC
}
}

// TopicActions are the actions to take on a topic
type TopicActions struct {
Subscribe func(topic string) error
Publish func(ctx context.Context, topic string, payload []byte) error
Unsubscribe func(topic string) error
}

// Connect connects to the MQTT broker
func (connection *MQTTConnection) Connect(ctx context.Context, fleetMQTTURL, token, agentID string, topics TokenResponseTopics, backends map[string]backend.Backend, clientID, zone string, labels map[string]string) error {
// Parse the ORB URL
Expand Down Expand Up @@ -143,8 +150,11 @@ func (connection *MQTTConnection) Connect(ctx context.Context, fleetMQTTURL, tok
pr.Packet.Payload,
orgID,
agentID,
connection.subscribeToTopic,
connection.publishToTopic,
TopicActions{
Subscribe: connection.subscribeToTopic,
Publish: connection.publishToTopic,
Unsubscribe: connection.unsubscribeFromTopic,
},
)
if err != nil {
connection.logger.Error("failed to dispatch to handlers", "error", err)
Expand Down Expand Up @@ -194,6 +204,13 @@ func (connection *MQTTConnection) subscribeToTopic(topic string) error {
return err
}

func (connection *MQTTConnection) unsubscribeFromTopic(topic string) error {
_, err := connection.connectionManager.Unsubscribe(context.Background(), &paho.Unsubscribe{
Topics: []string{topic},
})
return err
}

func (connection *MQTTConnection) publishToTopic(ctx context.Context, topic string, payload []byte) error {
connection.logger.Debug("publishing to topic", "topic", topic, "payload", string(payload))
_, err := connection.connectionManager.Publish(ctx, &paho.Publish{
Expand Down
45 changes: 43 additions & 2 deletions agent/configmgr/fleet/from_rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"log/slog"

"github.com/netboxlabs/orb-agent/agent/backend"
"github.com/netboxlabs/orb-agent/agent/config"
"github.com/netboxlabs/orb-agent/agent/configmgr/fleet/messages"
"github.com/netboxlabs/orb-agent/agent/policymgr"
Expand All @@ -25,7 +26,7 @@ func NewMessaging(logger *slog.Logger, policyManager policymgr.PolicyManager) *M
}

// DispatchToHandlers dispatches the message to the appropriate handler
func (messaging *Messaging) DispatchToHandlers(ctx context.Context, payload []byte, orgID string, agentID string, subscribeToTopic func(topic string) error, publishToTopic func(ctx context.Context, topic string, payload []byte) error) error {
func (messaging *Messaging) DispatchToHandlers(ctx context.Context, payload []byte, orgID string, agentID string, topicActions TopicActions) error {
var rpc messages.RPC
if err := json.Unmarshal(payload, &rpc); err != nil {
messaging.logger.Error("failed to unmarshal RPC", "error", err)
Expand All @@ -46,14 +47,21 @@ func (messaging *Messaging) DispatchToHandlers(ctx context.Context, payload []by
messaging.logger.Error("failed to unmarshal payload", "error", err)
return err
}
messaging.handleGroupMemberships(ctx, groupMemberships.Payload, orgID, agentID, subscribeToTopic, publishToTopic)
messaging.handleGroupMemberships(ctx, groupMemberships.Payload, orgID, agentID, topicActions.Subscribe, topicActions.Publish)
case messages.AgentPolicyRPCFunc:
agentPolicies := messages.AgentPolicyRPC{}
if err := json.Unmarshal(payload, &agentPolicies); err != nil {
messaging.logger.Error("failed to unmarshal payload", "error", err)
return err
}
messaging.handleAgentPolicies(agentPolicies.Payload, agentPolicies.FullList)
case messages.GroupRemovedRPCFunc:
groupRemoved := messages.GroupRemovedRPC{}
if err := json.Unmarshal(payload, &groupRemoved); err != nil {
messaging.logger.Error("failed to unmarshal payload", "error", err)
return err
}
messaging.handleAgentGroupRemoval(groupRemoved.Payload, topicActions.Unsubscribe)
default:
messaging.logger.Debug("unknown rpc function", "func", rpc.Func)
}
Expand Down Expand Up @@ -125,3 +133,36 @@ func (messaging *Messaging) handleAgentPolicies(rpc []messages.AgentPolicyRPCPay
}
messaging.logger.Info("successfully processed agent policies", "count", len(rpc))
}

func (messaging *Messaging) handleAgentGroupRemoval(rpc messages.GroupRemovedRPCPayload, unsubscribeFromTopic func(topic string) error) {
err := unsubscribeFromTopic(rpc.AgentGroupID)
if err != nil {
messaging.logger.Error("failed to unsubscribe from group topic", "error", err)
return
}

policies, err := messaging.policyManager.GetRepo().GetAll()
if err != nil {
return
}

for _, policy := range policies {
delete(policy.GroupIDs, rpc.AgentGroupID)

if len(policy.GroupIDs) == 0 {
messaging.logger.Info("policy no longer used by any group, removing", "policy_id", policy.ID, "policy_name", policy.Name)

err = messaging.policyManager.RemovePolicy(policy.ID, policy.Name, policy.Backend)
if err != nil {
messaging.logger.Warn("failed to remove a policy, ignoring", "policy_id", policy.ID, "policy_name", policy.Name, "error", err)
continue
}
} else {
for _, datasetID := range rpc.Datasets {
if backend.HaveBackend(policy.Backend) {
messaging.policyManager.RemovePolicyDataset(policy.ID, datasetID, backend.GetBackend(policy.Backend))
}
}
}
}
}
Loading
Loading