Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 83 additions & 8 deletions agent/configmgr/fleet.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,30 @@ import (
var _ Manager = (*fleetConfigManager)(nil)

// fleetConfigManager is the Manager implementation that obtains its
// configuration through a Fleet MQTT connection.
//
// NOTE(review): this span is a rendered diff hunk, so the leading fields are
// listed twice (pre-change and post-change); confirm against the real file,
// where each field is presumably declared once.
type fleetConfigManager struct {
logger *slog.Logger
connection *fleet.MQTTConnection
authTokenManager *fleet.AuthTokenManager
resetChan chan struct{}
backendState backend.StateRetriever
policyManager policymgr.PolicyManager
otlpBridge *otlpbridge.BridgeServer
logger *slog.Logger
connection *fleet.MQTTConnection
authTokenManager *fleet.AuthTokenManager
resetChan chan struct{}
// reconnectChan is shared with the MQTT connection; each value received on
// it makes the goroutine started in Start call refreshAndReconnect.
reconnectChan chan struct{}
backendState backend.StateRetriever
policyManager policymgr.PolicyManager
otlpBridge *otlpbridge.BridgeServer
// The fields below are recorded by Start so that refreshAndReconnect can
// re-establish the connection with the same inputs.
config config.Config
backends map[string]backend.Backend
labels map[string]string
configYaml string
// connectionDetails holds the details last handed to the connection; it is
// overwritten by refreshAndReconnect with the details built from a new JWT.
connectionDetails fleet.ConnectionDetails
}

func newFleetConfigManager(logger *slog.Logger, pMgr policymgr.PolicyManager, backendState backend.StateRetriever) *fleetConfigManager {
resetChan := make(chan struct{}, 1)
reconnectChan := make(chan struct{}, 1)
return &fleetConfigManager{
logger: logger,
connection: fleet.NewMQTTConnection(logger, pMgr, resetChan, backendState),
connection: fleet.NewMQTTConnection(logger, pMgr, resetChan, reconnectChan, backendState),
authTokenManager: fleet.NewAuthTokenManager(logger),
resetChan: resetChan,
reconnectChan: reconnectChan,
backendState: backendState,
policyManager: pMgr,
}
Expand Down Expand Up @@ -106,6 +114,14 @@ func (fleetManager *fleetConfigManager) Start(cfg config.Config, backends map[st
if err != nil {
return fmt.Errorf("failed to convert config to safe string: %w", err)
}

// Store connection state for reconnection
fleetManager.config = cfg
fleetManager.backends = backends
fleetManager.labels = cfg.OrbAgent.Labels
fleetManager.configYaml = string(configYaml)
fleetManager.connectionDetails = connectionDetails

err = fleetManager.connection.Connect(ctx, connectionDetails, backends, cfg.OrbAgent.Labels, string(configYaml))
if err != nil {
return err
Expand Down Expand Up @@ -158,6 +174,65 @@ func (fleetManager *fleetConfigManager) Start(cfg config.Config, backends map[st
fleetManager.logger.Info("OTLP bridge bound to Fleet MQTT", slog.String("topic", topics.Ingest))
})

// Start goroutine to handle reconnect requests (JWT refresh)
go func() {
for range fleetManager.reconnectChan {
fleetManager.logger.Info("JWT refresh and reconnection requested")
if err := fleetManager.refreshAndReconnect(ctx, timeout); err != nil {
fleetManager.logger.Error("failed to refresh and reconnect", "error", err)
}
}
}()

return nil
}

// refreshAndReconnect obtains a fresh JWT, derives the connection details that
// follow from its claims, and re-establishes the MQTT connection with them.
//
// ctx is passed to the token refresh and to the reconnect; timeout is forwarded
// to the connection's Reconnect call. The newly built details replace the
// stored connectionDetails before the reconnect is attempted. The first step
// that fails aborts the sequence and its error is returned wrapped.
func (fleetManager *fleetConfigManager) refreshAndReconnect(ctx context.Context, timeout time.Duration) error {
	// Step 1: ask the auth token manager for a new token.
	refreshed, err := fleetManager.authTokenManager.RefreshToken(ctx)
	if err != nil {
		return fmt.Errorf("failed to refresh token: %w", err)
	}

	// Step 2: the claims carry the broker URL, agent ID and zone.
	claims, err := fleet.ParseJWTClaims(refreshed.AccessToken)
	if err != nil {
		return fmt.Errorf("failed to parse JWT claims: %w", err)
	}

	// Step 3: topics are derived from the claims, so they are rebuilt too.
	newTopics, err := fleet.GenerateTopicsFromTemplate(claims)
	if err != nil {
		return fmt.Errorf("failed to generate topics: %w", err)
	}

	fleetManager.logger.Info(
		"refreshed JWT and generated new topics",
		"heartbeat_topic", newTopics.Heartbeat,
		"capabilities_topic", newTopics.Capabilities,
		"inbox_topic", newTopics.Inbox,
		"outbox_topic", newTopics.Outbox,
	)

	// Step 4: assemble the details; the client ID still comes from the config
	// captured at Start, everything else from the new token.
	details := fleet.ConnectionDetails{
		MQTTURL:  claims.MqttURL,
		Token:    refreshed.AccessToken,
		AgentID:  claims.AgentID,
		Topics:   *newTopics,
		ClientID: fleetManager.config.OrbAgent.ConfigManager.Sources.Fleet.ClientID,
		Zone:     claims.Zone,
	}
	fleetManager.connectionDetails = details

	// Step 5: swap the live connection over to the new credentials.
	if err := fleetManager.connection.Reconnect(
		ctx,
		details,
		fleetManager.backends,
		fleetManager.labels,
		fleetManager.configYaml,
		timeout,
	); err != nil {
		return fmt.Errorf("failed to reconnect: %w", err)
	}

	fleetManager.logger.Info("successfully refreshed JWT and reconnected")
	return nil
}

Expand Down
40 changes: 39 additions & 1 deletion agent/configmgr/fleet/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,14 @@ import (

// AuthTokenManager manages auth tokens. Besides requesting tokens it records
// the arguments of the request and the resulting token so the token can be
// refreshed and its expiry checked later.
//
// NOTE(review): this span is a rendered diff hunk, so logger is listed twice
// (pre-change and post-change); confirm against the real file.
type AuthTokenManager struct {
logger *slog.Logger
logger *slog.Logger
// tokenURL, skipTLS, timeout, clientID and clientSecret are the arguments
// recorded by GetToken; RefreshToken replays them. An empty tokenURL means
// nothing has been recorded yet.
tokenURL string
skipTLS bool
timeout time.Duration
clientID string
clientSecret string
// lastToken is the response stored by the most recent successful GetToken.
lastToken *TokenResponse
// tokenExpiresAt is the issue time plus ExpiresIn seconds minus a 5-minute
// buffer. GetToken only updates it when ExpiresIn is positive.
tokenExpiresAt time.Time
}

// NewAuthTokenManager creates a new AuthTokenManager
Expand Down Expand Up @@ -46,6 +53,13 @@ func (fleetManager *AuthTokenManager) GetToken(ctx context.Context, tokenURL str
return nil, fmt.Errorf("client secret cannot be empty")
}

// Store credentials for future refresh
fleetManager.tokenURL = tokenURL
fleetManager.skipTLS = skipTLS
fleetManager.timeout = timeout
fleetManager.clientID = clientID
fleetManager.clientSecret = clientSecret

fleetManager.logger.Debug("requesting access token", "token_url", tokenURL, "client_id", clientID)

scopes := []string{
Expand Down Expand Up @@ -121,5 +135,29 @@ func (fleetManager *AuthTokenManager) GetToken(ctx context.Context, tokenURL str
"expires_in", TokenResponse.ExpiresIn,
"mqtt_url", TokenResponse.MQTTURL)

// Store token and calculate expiration time (with 5-minute buffer)
fleetManager.lastToken = &TokenResponse
if TokenResponse.ExpiresIn > 0 {
fleetManager.tokenExpiresAt = time.Now().Add(time.Duration(TokenResponse.ExpiresIn)*time.Second - 5*time.Minute)
}

return &TokenResponse, nil
}

// RefreshToken requests a new token by replaying the arguments recorded by
// GetToken. When no token URL has been recorded yet it returns an error
// without calling GetToken.
func (fleetManager *AuthTokenManager) RefreshToken(ctx context.Context) (*TokenResponse, error) {
	endpoint := fleetManager.tokenURL
	if endpoint == "" {
		return nil, fmt.Errorf("cannot refresh token: credentials not initialized")
	}

	fleetManager.logger.Info("refreshing JWT token")

	// Delegate to GetToken so the stored token and expiry are updated the same
	// way as for the initial request.
	return fleetManager.GetToken(
		ctx,
		endpoint,
		fleetManager.skipTLS,
		fleetManager.timeout,
		fleetManager.clientID,
		fleetManager.clientSecret,
	)
}

// IsTokenExpired reports whether a new token is needed: either no token has
// been stored yet, or the current time is past the stored expiry (which
// already includes the safety buffer applied when the token was stored).
func (fleetManager *AuthTokenManager) IsTokenExpired() bool {
	return fleetManager.lastToken == nil || time.Now().After(fleetManager.tokenExpiresAt)
}
111 changes: 99 additions & 12 deletions agent/configmgr/fleet/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,21 @@ import (

// MQTTConnection manages the MQTT connection, including the publish-failure
// bookkeeping that decides when a JWT refresh and reconnect is requested.
//
// NOTE(review): this span is a rendered diff hunk, so the leading fields are
// listed twice (pre-change and post-change); confirm against the real file.
type MQTTConnection struct {
logger *slog.Logger
connectionManager *autopaho.ConnectionManager
heartbeater *heartbeater
messaging *Messaging
resetChan chan struct{}
onReadyHooks []func(cm *autopaho.ConnectionManager, topics TokenResponseTopics)
connectionTopics TokenResponseTopics
logger *slog.Logger
connectionManager *autopaho.ConnectionManager
heartbeater *heartbeater
messaging *Messaging
resetChan chan struct{}
onReadyHooks []func(cm *autopaho.ConnectionManager, topics TokenResponseTopics)
connectionTopics TokenResponseTopics
// reconnectChan receives a non-blocking signal when repeated publish
// failures suggest the JWT should be refreshed and the connection rebuilt.
reconnectChan chan struct{}
// capabilitiesFailCount counts failed capabilities publishes; reaching 2
// triggers a reconnect request. It is zeroed on success and on Reconnect.
capabilitiesFailCount int
// groupMembershipFailCount counts failed group-membership request
// publishes; reaching 2 triggers a reconnect request. It is zeroed on
// success and on Reconnect.
groupMembershipFailCount int
// heartbeatFailCount counts failed heartbeat publishes; reaching 5 triggers
// a reconnect request. It is zeroed by any successful publishToTopic call
// and on Reconnect.
heartbeatFailCount int
}

// NewMQTTConnection creates a new MQTTConnection
func NewMQTTConnection(logger *slog.Logger, pMgr policymgr.PolicyManager, resetChan chan struct{}, backendState backend.StateRetriever) *MQTTConnection {
func NewMQTTConnection(logger *slog.Logger, pMgr policymgr.PolicyManager, resetChan chan struct{}, reconnectChan chan struct{}, backendState backend.StateRetriever) *MQTTConnection {
groupManager := newGroupManager()
return &MQTTConnection{
connectionManager: nil,
Expand All @@ -36,6 +40,7 @@ func NewMQTTConnection(logger *slog.Logger, pMgr policymgr.PolicyManager, resetC
messaging: NewMessaging(logger, pMgr, resetChan, &groupManager),
resetChan: resetChan,
onReadyHooks: make([]func(cm *autopaho.ConnectionManager, topics TokenResponseTopics), 0),
reconnectChan: reconnectChan,
}
}

Expand Down Expand Up @@ -103,7 +108,23 @@ func (connection *MQTTConnection) Connect(ctx context.Context, details Connectio
}

// start heartbeat loop bound to the same connection-level context
go connection.heartbeater.sendHeartbeats(ctx, func() {}, details.Topics.Heartbeat, details.ClientID, connection.publishToTopic)
go connection.heartbeater.sendHeartbeats(ctx, func() {}, details.Topics.Heartbeat, details.ClientID, connection.publishToTopic, func() {
// Track heartbeat failures
connection.heartbeatFailCount++
connection.logger.Error("heartbeat publish failed",
"fail_count", connection.heartbeatFailCount)

// After 5 consecutive failures, trigger reconnect
if connection.heartbeatFailCount >= 5 {
connection.logger.Warn("heartbeat publish failed 5 times, triggering JWT refresh and reconnect")
select {
case connection.reconnectChan <- struct{}{}:
default:
connection.logger.Debug("reconnect already in progress")
}
connection.heartbeatFailCount = 0
}
})

connection.messaging.sendCapabilities(ctx, backends, labels, configFile, func(ctx context.Context, payload []byte) error {
_, err := cm.Publish(ctx, &paho.Publish{
Expand All @@ -113,11 +134,26 @@ func (connection *MQTTConnection) Connect(ctx context.Context, details Connectio
Retain: false,
})
if err != nil {
// TODO: reconnect?
connection.logger.Error("failed to publish capabilities", "error", err)
connection.capabilitiesFailCount++
connection.logger.Error("failed to publish capabilities",
"error", err,
"fail_count", connection.capabilitiesFailCount)

// After 1 retry (2 failures), trigger reconnect
if connection.capabilitiesFailCount >= 2 {
connection.logger.Warn("capabilities publish failed twice, triggering JWT refresh and reconnect")
select {
case connection.reconnectChan <- struct{}{}:
default:
connection.logger.Debug("reconnect already in progress")
}
connection.capabilitiesFailCount = 0
}
return err
}

// Reset counter on success
connection.capabilitiesFailCount = 0
connection.logger.Debug("capabilities sent",
"topic", details.Topics.Capabilities,
"payload", string(payload),
Expand All @@ -135,9 +171,26 @@ func (connection *MQTTConnection) Connect(ctx context.Context, details Connectio
Retain: false,
})
if err != nil {
connection.logger.Error("failed to publish group memberships request", "error", err)
connection.groupMembershipFailCount++
connection.logger.Error("failed to publish group memberships request",
"error", err,
"fail_count", connection.groupMembershipFailCount)

// After 1 retry (2 failures), trigger reconnect
if connection.groupMembershipFailCount >= 2 {
connection.logger.Warn("group membership publish failed twice, triggering JWT refresh and reconnect")
select {
case connection.reconnectChan <- struct{}{}:
default:
connection.logger.Debug("reconnect already in progress")
}
connection.groupMembershipFailCount = 0
}
return err
}

// Reset counter on success
connection.groupMembershipFailCount = 0
return nil
})
},
Expand Down Expand Up @@ -205,6 +258,37 @@ func (connection *MQTTConnection) Connect(ctx context.Context, details Connectio
return nil
}

// Reconnect tears down the current MQTT connection, if there is one, and
// connects again using details (for example after a JWT refresh).
//
// Parameters:
//   - ctx: parent context for both the disconnect and the new connection.
//   - details: connection details for the NEW connection.
//   - backends, labels, configFile: forwarded unchanged to Connect.
//   - timeout: upper bound for the disconnect of the old connection.
//
// A failure to disconnect the old connection is logged and otherwise ignored
// so that a new connection is still attempted. All publish-failure counters
// are reset before connecting. An error is returned only when Connect fails.
func (connection *MQTTConnection) Reconnect(ctx context.Context, details ConnectionDetails, backends map[string]backend.Backend, labels map[string]string, configFile string, timeout time.Duration) error {
	connection.logger.Info("reconnecting to MQTT broker with refreshed credentials")

	// Disconnect the existing connection
	if connection.connectionManager != nil {
		// Stop the heartbeater on the topic of the connection being replaced.
		// details describes the new connection, and its topics may differ from
		// the ones the old session was using. Fall back to the new topic only
		// when no previous topic is recorded.
		heartbeatTopic := connection.connectionTopics.Heartbeat
		if heartbeatTopic == "" {
			heartbeatTopic = details.Topics.Heartbeat
		}

		disconnectCtx, cancel := context.WithTimeout(ctx, timeout)
		connection.heartbeater.stop(heartbeatTopic, connection.publishToTopic)
		err := connection.connectionManager.Disconnect(disconnectCtx)
		cancel()
		if err != nil {
			connection.logger.Error("failed to disconnect during reconnect", "error", err)
			// Continue anyway to try to establish new connection
		}
	}

	// Reset failure counters so the new connection starts with a clean slate
	connection.capabilitiesFailCount = 0
	connection.groupMembershipFailCount = 0
	connection.heartbeatFailCount = 0

	// Connect with new details
	if err := connection.Connect(ctx, details, backends, labels, configFile); err != nil {
		return fmt.Errorf("failed to connect during reconnect: %w", err)
	}

	connection.logger.Info("successfully reconnected to MQTT broker")
	return nil
}

// Disconnect disconnects from the MQTT broker
func (connection *MQTTConnection) Disconnect(ctx context.Context, heartbeatTopic string) error {
connection.heartbeater.stop(heartbeatTopic, connection.publishToTopic)
Expand Down Expand Up @@ -239,5 +323,8 @@ func (connection *MQTTConnection) publishToTopic(ctx context.Context, topic stri
connection.logger.Error("failed to publish to topic", "topic", topic, "error", err)
return err
}
// Reset heartbeat failure counter on successful publish
// (heartbeats use this function, so successful publish means connection is ok)
connection.heartbeatFailCount = 0
return nil
}
6 changes: 4 additions & 2 deletions agent/configmgr/fleet/connection_hooks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ func (noopBackendState) Get() map[string]*backend.State { return map[string]*bac
func TestAddOnReadyHook_RegistersHook(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
reset := make(chan struct{}, 1)
conn := NewMQTTConnection(logger, noopPM{}, reset, noopBackendState{})
reconnect := make(chan struct{}, 1)
conn := NewMQTTConnection(logger, noopPM{}, reset, reconnect, noopBackendState{})

if len(conn.onReadyHooks) != 0 {
t.Fatalf("expected 0 hooks initially, got %d", len(conn.onReadyHooks))
Expand All @@ -48,7 +49,8 @@ func TestAddOnReadyHook_RegistersHook(t *testing.T) {
func TestConnect_StoresTopicsBeforeConnecting(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
reset := make(chan struct{}, 1)
conn := NewMQTTConnection(logger, noopPM{}, reset, noopBackendState{})
reconnect := make(chan struct{}, 1)
conn := NewMQTTConnection(logger, noopPM{}, reset, reconnect, noopBackendState{})

details := ConnectionDetails{
MQTTURL: "mqtt://localhost:1883",
Expand Down
6 changes: 4 additions & 2 deletions agent/configmgr/fleet/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ func TestFleetConfigManager_Connect_InvalidURL(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
mockPMgr := &mockPolicyManagerForFleet{}
resetChan := make(chan struct{}, 1)
connection := NewMQTTConnection(logger, mockPMgr, resetChan, &mockBackendState{})
reconnectChan := make(chan struct{}, 1)
connection := NewMQTTConnection(logger, mockPMgr, resetChan, reconnectChan, &mockBackendState{})

// Act with invalid URL
backends := make(map[string]backend.Backend)
Expand All @@ -93,7 +94,8 @@ func TestFleetConfigManager_Connect_ValidURL(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError}))
mockPMgr := &mockPolicyManagerForFleet{}
resetChan := make(chan struct{}, 1)
connection := NewMQTTConnection(logger, mockPMgr, resetChan, &mockBackendState{})
reconnectChan := make(chan struct{}, 1)
connection := NewMQTTConnection(logger, mockPMgr, resetChan, reconnectChan, &mockBackendState{})

// Act with valid URL but don't expect successful connection
// since we don't have a real MQTT server
Expand Down
Loading
Loading