Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow a Hook to reload an individual client session #402

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ const (
StoredInflightMessages
StoredRetainedMessages
StoredSysInfo
StoredClientByID
)

var (
Expand Down Expand Up @@ -114,6 +115,7 @@ type Hook interface {
StoredInflightMessages() ([]storage.Message, error)
StoredRetainedMessages() ([]storage.Message, error)
StoredSysInfo() (storage.SystemInfo, error)
StoredClientByID(id string, username []byte) (string, []storage.Subscription, []storage.Message, error)
}

// HookOptions contains values which are inherited from the server on initialisation.
Expand Down Expand Up @@ -679,6 +681,25 @@ func (h *Hooks) OnACLCheck(cl *Client, topic string, write bool) bool {
return false
}

// StoredClientByID returns the state of the stored client with the given session ID, if any.
func (h *Hooks) StoredClientByID(id string, username []byte) (oldRemote string, subs []storage.Subscription, msgs []storage.Message, err error) {
for _, hook := range h.GetAll() {
if hook.Provides(StoredClientByID) {
oldRemote, subs, msgs, err = hook.StoredClientByID(id, username)
if err != nil {
h.Log.Error("failed to load client by ID", "error", err, "hook", hook.ID())
return
}

if oldRemote != "" && err == nil {
return
}
}
}

return
}

// HookBase provides a set of default methods for each hook. It should be embedded in
// all hooks.
type HookBase struct {
Expand Down Expand Up @@ -859,3 +880,8 @@ func (h *HookBase) StoredRetainedMessages() (v []storage.Message, err error) {
func (h *HookBase) StoredSysInfo() (v storage.SystemInfo, err error) {
return
}

// StoredClientByID returns the state of the stored client with the given session ID, if any.
func (h *HookBase) StoredClientByID(id string, username []byte) (oldRemote string, subs []storage.Subscription, msgs []storage.Message, err error) {
return
}
84 changes: 66 additions & 18 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,11 @@ func (s *Server) attachClient(cl *Client, listener string) error {
expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean)
s.hooks.OnDisconnect(cl, err, expire)

if s.hooks.Provides(StoredClientByID) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have mixed feelings about this. I like the overall idea of the implementation but I'm not sure we should be adding any hook-specific logic to the server code as it defeats the objective of having hooks. In this case the hook is mostly being used as a feature flag, whereas the hook logic should be stored inside the hook.

Maybe a better option here would be to enhance OnDisconnect so that it can overwrite the value of expire:

expire = s.hooks.OnDisconnect(cl, err, expire)

and modify the method signature accordingly:

func (h *Hook) OnDisconnect(cl *mqtt.Client, _ error, expire bool) bool {

// Hooks are capable of reloading a persistent client session, so I can forget it
expire = true
}

if expire && atomic.LoadUint32(&cl.State.isTakenOver) == 0 {
cl.ClearInflights()
s.UnsubscribeClient(cl)
Expand Down Expand Up @@ -596,6 +601,42 @@ func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool {
return true // [MQTT-3.2.2-3]
}

// Look up a stored client that's not in memory yet:
if s.hooks.Provides(StoredClientByID) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same issue here. Possible solution would be to add this logic to the OnSessionEstablish hook and overwrite the values of Client.State directly, but without adding the client to the server. This will avoid the client being picked up in inheritClientSession and overwriting the new networked client values.

oldRemote, subs, msgs, err := s.hooks.StoredClientByID(cl.ID, cl.Properties.Username)
if err == nil && oldRemote != "" {
// Instantiate in-flight messages to deliver:
if len(msgs) > 0 {
inf := NewInflights()
for _, msg := range msgs {
inf.Set(msg.ToPacket())
}
cl.State.Inflight = inf
}

// Instantiate stored subscriptions:
for _, sub := range subs {
sb := packets.Subscription{
Filter: sub.Filter,
RetainHandling: sub.RetainHandling,
Qos: sub.Qos,
RetainAsPublished: sub.RetainAsPublished,
NoLocal: sub.NoLocal,
Identifier: sub.Identifier,
}
existed := !s.Topics.Subscribe(cl.ID, sb) // [MQTT-3.8.4-3]
if !existed {
atomic.AddInt64(&s.Info.Subscriptions, 1)
}
cl.State.Subscriptions.Add(sb.Filter, sb)
}

s.Log.Debug("session taken over (persistent)", "client", cl.ID, "old_remote", oldRemote, "new_remote", cl.Net.Remote)

return true
}
}

if atomic.LoadInt64(&s.Info.ClientsConnected) > atomic.LoadInt64(&s.Info.ClientsMaximum) {
atomic.AddInt64(&s.Info.ClientsMaximum, 1)
}
Expand Down Expand Up @@ -1014,6 +1055,7 @@ func (s *Server) publishToSubscribers(pk packets.Packet) {
}
}

// publishToClient delivers a published message to a single subscriber client.
func (s *Server) publishToClient(cl *Client, sub packets.Subscription, pk packets.Packet) (packets.Packet, error) {
if sub.NoLocal && pk.Origin == cl.ID {
return pk, nil // [MQTT-3.8.3-3]
Expand Down Expand Up @@ -1636,24 +1678,7 @@ func (s *Server) loadSubscriptions(v []storage.Subscription) {
// loadClients restores clients from the datastore.
func (s *Server) loadClients(v []storage.Client) {
for _, c := range v {
cl := s.NewClient(nil, c.Listener, c.ID, false)
cl.Properties.Username = c.Username
cl.Properties.Clean = c.Clean
cl.Properties.ProtocolVersion = c.ProtocolVersion
cl.Properties.Props = packets.Properties{
SessionExpiryInterval: c.Properties.SessionExpiryInterval,
SessionExpiryIntervalFlag: c.Properties.SessionExpiryIntervalFlag,
AuthenticationMethod: c.Properties.AuthenticationMethod,
AuthenticationData: c.Properties.AuthenticationData,
RequestProblemInfoFlag: c.Properties.RequestProblemInfoFlag,
RequestProblemInfo: c.Properties.RequestProblemInfo,
RequestResponseInfo: c.Properties.RequestResponseInfo,
ReceiveMaximum: c.Properties.ReceiveMaximum,
TopicAliasMaximum: c.Properties.TopicAliasMaximum,
User: c.Properties.User,
MaximumPacketSize: c.Properties.MaximumPacketSize,
}
cl.Properties.Will = Will(c.Will)
cl := s.newClientFromStorage(&c)

// cancel the context, update cl.State such as disconnected time and stopCause.
cl.Stop(packets.ErrServerShuttingDown)
Expand All @@ -1669,6 +1694,29 @@ func (s *Server) loadClients(v []storage.Client) {
}
}

// newClientFromStorage creates a Client from a storage.Client.
func (s *Server) newClientFromStorage(c *storage.Client) *Client {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not against it, but what's the intention behind extracting this into an additional method if it's only called from the same place?

cl := s.NewClient(nil, c.Listener, c.ID, false)
cl.Properties.Username = c.Username
cl.Properties.Clean = c.Clean
cl.Properties.ProtocolVersion = c.ProtocolVersion
cl.Properties.Props = packets.Properties{
SessionExpiryInterval: c.Properties.SessionExpiryInterval,
SessionExpiryIntervalFlag: c.Properties.SessionExpiryIntervalFlag,
AuthenticationMethod: c.Properties.AuthenticationMethod,
AuthenticationData: c.Properties.AuthenticationData,
RequestProblemInfoFlag: c.Properties.RequestProblemInfoFlag,
RequestProblemInfo: c.Properties.RequestProblemInfo,
RequestResponseInfo: c.Properties.RequestResponseInfo,
ReceiveMaximum: c.Properties.ReceiveMaximum,
TopicAliasMaximum: c.Properties.TopicAliasMaximum,
User: c.Properties.User,
MaximumPacketSize: c.Properties.MaximumPacketSize,
}
cl.Properties.Will = Will(c.Will)
return cl
}

// loadInflight restores inflight messages from the datastore.
func (s *Server) loadInflight(v []storage.Message) {
for _, msg := range v {
Expand Down
Loading