From f734b60762d050177d4ca4cb90370adf24412f41 Mon Sep 17 00:00:00 2001 From: Florimond Husquinet Date: Thu, 16 Nov 2023 22:11:00 +0100 Subject: [PATCH] First untested attempt at a Limiter by size --- internal/network/mqtt/buffer.go | 2 +- internal/network/mqtt/mqtt.go | 4 +- internal/provider/storage/storage.go | 63 ++++++++++++++++++++++++++-- internal/service/history/history.go | 4 +- 4 files changed, 65 insertions(+), 8 deletions(-) diff --git a/internal/network/mqtt/buffer.go b/internal/network/mqtt/buffer.go index 4b9f5608..591a0749 100644 --- a/internal/network/mqtt/buffer.go +++ b/internal/network/mqtt/buffer.go @@ -23,7 +23,7 @@ const smallBufferSize = 64 const maxInt = int(^uint(0) >> 1) // buffers are reusable fixed-side buffers for faster encoding. -var buffers = newBufferPool(maxMessageSize) +var buffers = newBufferPool(MaxMessageSize) // bufferPool represents a thread safe buffer pool type bufferPool struct { diff --git a/internal/network/mqtt/mqtt.go b/internal/network/mqtt/mqtt.go index 266374d5..110baf32 100644 --- a/internal/network/mqtt/mqtt.go +++ b/internal/network/mqtt/mqtt.go @@ -23,7 +23,7 @@ import ( const ( maxHeaderSize = 6 maxTopicSize = 1024 // max MQTT header size - maxMessageSize = 65536 // max MQTT message size is impossible to increase as per protocol (uint16 len) + MaxMessageSize = 65536 // max MQTT message size is impossible to increase as per protocol (uint16 len) ) // ErrMessageTooLarge occurs when a message encoded/decoded is larger than max MQTT frame. @@ -327,7 +327,7 @@ func (p *Publish) EncodeTo(w io.Writer) (int, error) { length += 2 } - if length > maxMessageSize { + if length > MaxMessageSize { return 0, ErrMessageTooLarge } diff --git a/internal/provider/storage/storage.go b/internal/provider/storage/storage.go index 6bb19dc2..da1822e8 100644 --- a/internal/provider/storage/storage.go +++ b/internal/provider/storage/storage.go @@ -71,7 +71,7 @@ type lookupQuery struct { UntilTime int64 // Lookup stops when reaches this time. UntilID message.ID // Lookup stops when reaches this message ID. LimitByCount *MessageNumberLimiter - //LimitBySize *MessageSizeLimiter + LimitBySize *MessageSizeLimiter } // newLookupQuery creates a new lookup query @@ -87,6 +87,8 @@ func newLookupQuery(ssid message.Ssid, from, until time.Time, untilID message.ID switch v := limiter.(type) { case *MessageNumberLimiter: query.LimitByCount = v + case *MessageSizeLimiter: + query.LimitBySize = v } return query } @@ -95,6 +97,8 @@ func (q *lookupQuery) Limiter() Limiter { switch { case q.LimitByCount != nil: return q.LimitByCount + case q.LimitBySize != nil: + return q.LimitBySize default: return &MessageNumberLimiter{} } @@ -109,13 +113,24 @@ type Limiter interface { // parameter in the Query() function. type MessageNumberLimiter struct { count int64 `binary:"-"` - MsgLimit int64 + MsgLimit int64 // TODO: why is this exported? } func (limiter *MessageNumberLimiter) Admit(m *message.Message) bool { - admit := limiter.count < limiter.MsgLimit + // As this function won't be called multiple times once the limit is reached, + // the following implementation should be faster than using a branching statement + // to check if the limit is reached, before incrementing the counter. limiter.count += 1 - return admit + return limiter.count <= limiter.MsgLimit + + // The following implementation would use a branching each time the function is called. + /* + if limiter.count < limiter.MsgLimit { + limiter.count += 1 + return true + } + return false + */ } func (limiter *MessageNumberLimiter) Limit(frame *message.Frame) { @@ -126,6 +141,46 @@ func NewMessageNumberLimiter(limit int64) Limiter { return &MessageNumberLimiter{MsgLimit: limit} } +// MessageSizeLimiter provide an Limiter implementation based on both the +// number of messages and the total size of the response. +type MessageSizeLimiter struct { + count int64 `binary:"-"` + size int64 `binary:"-"` + countLimit int64 + sizeLimit int64 +} + +func (limiter *MessageSizeLimiter) Admit(m *message.Message) bool { + // As this function won't be called multiple times once the limit is reached, + // the following implementation should be faster than using a branching statement + // to check if the limit is reached, before incrementing the counter. + // Todo: discuss whether it's ok to make that assumption + + // This size calculation comes from mqtt.go:EncodeTo() line 320. + // Todo: discuss whether this is the best way to calculate the size. + // 2 bytes for message ID. + limiter.size += int64(2 + len(m.Channel) + len(m.Payload)) + limiter.count += 1 + return limiter.count <= limiter.countLimit && limiter.size <= limiter.sizeLimit +} + +func (limiter *MessageSizeLimiter) Limit(frame *message.Frame) { + // Limit takes the last N elements that fit into a message, sorted by message time + frame.Sort() + for i, m := range *frame { + totalSize := int64(2 + len(m.Channel) + len(m.Payload)) + if limiter.size+totalSize > limiter.sizeLimit { + *frame = (*frame)[:i] + break + } + limiter.size += totalSize + } +} + +func NewMessageSizeLimiter(countLimit, sizeLimit int64) Limiter { + return &MessageSizeLimiter{countLimit: countLimit, sizeLimit: sizeLimit} +} + // configUint32 retrieves an uint32 from the config func configUint32(config map[string]interface{}, name string, defaultValue uint32) uint32 { if v, ok := config[name]; ok { diff --git a/internal/service/history/history.go b/internal/service/history/history.go index 1db61b27..cdbb60eb 100644 --- a/internal/service/history/history.go +++ b/internal/service/history/history.go @@ -19,6 +19,7 @@ import ( "github.com/emitter-io/emitter/internal/errors" "github.com/emitter-io/emitter/internal/message" + "github.com/emitter-io/emitter/internal/network/mqtt" "github.com/emitter-io/emitter/internal/provider/logging" "github.com/emitter-io/emitter/internal/provider/storage" "github.com/emitter-io/emitter/internal/security" @@ -76,7 +77,8 @@ func (s *Service) OnRequest(c service.Conn, payload []byte) (service.Response, b ssid := message.NewSsid(key.Contract(), channel.Query) t0, t1 := channel.Window() // Get the window - messageLimiter := storage.NewMessageNumberLimiter(limit) + //messageLimiter := storage.NewMessageNumberLimiter(limit) + messageLimiter := storage.NewMessageSizeLimiter(limit, mqtt.MaxMessageSize) msgs, err := s.store.Query(ssid, t0, t1, request.LastMessageID, messageLimiter) if err != nil { logging.LogError("conn", "query last messages", err)