Skip to content

Commit

Permalink
fix: add missing topic selector checks in BoltTransport (#597)
Browse files Browse the repository at this point in the history
* fix: re-add missing topic selector checks

* fix: use SubscriberList in the Bolt transport

* fix: bolt history

* fix linter
  • Loading branch information
dunglas committed Dec 26, 2021
1 parent 99bf84c commit b2adf74
Show file tree
Hide file tree
Showing 9 changed files with 123 additions and 38 deletions.
41 changes: 21 additions & 20 deletions bolt_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ const defaultBoltBucketName = "updates"
// BoltTransport implements the TransportInterface using the Bolt database.
type BoltTransport struct {
sync.RWMutex
subscribers *SubscriberList
logger Logger
db *bolt.DB
bucketName string
size uint64
cleanupFrequency float64
subscribers map[*Subscriber]struct{}
closed chan struct{}
closedOnce sync.Once
lastSeq uint64
Expand Down Expand Up @@ -83,7 +83,7 @@ func NewBoltTransport(u *url.URL, l Logger, tss *TopicSelectorStore) (Transport,
bucketName: bucketName,
size: size,
cleanupFrequency: cleanupFrequency,
subscribers: make(map[*Subscriber]struct{}),
subscribers: NewSubscriberList(1e5),
closed: make(chan struct{}),
lastEventID: getDBLastEventID(db, bucketName),
}, nil
Expand Down Expand Up @@ -129,9 +129,9 @@ func (t *BoltTransport) Dispatch(update *Update) error {
return err
}

for subscriber := range t.subscribers {
if !subscriber.Dispatch(update, false) {
delete(t.subscribers, subscriber)
for _, s := range t.subscribers.MatchAny(update) {
if !s.Dispatch(update, false) {
t.subscribers.Remove(s)
}
}

Expand Down Expand Up @@ -182,7 +182,7 @@ func (t *BoltTransport) AddSubscriber(s *Subscriber) error {
}

t.Lock()
t.subscribers[s] = struct{}{}
t.subscribers.Add(s)
toSeq := t.lastSeq //nolint:ifshort
t.Unlock()

Expand All @@ -204,8 +204,8 @@ func (t *BoltTransport) RemoveSubscriber(s *Subscriber) error {
}

t.Lock()
delete(t.subscribers, s)
t.Unlock()
defer t.Unlock()
t.subscribers.Remove(s)

return nil
}
Expand All @@ -214,13 +214,13 @@ func (t *BoltTransport) RemoveSubscriber(s *Subscriber) error {
func (t *BoltTransport) GetSubscribers() (string, []*Subscriber, error) {
t.RLock()
defer t.RUnlock()
subscribers := make([]*Subscriber, len(t.subscribers))

i := 0
for subscriber := range t.subscribers {
subscribers[i] = subscriber
i++
}
var subscribers []*Subscriber
t.subscribers.Walk(0, func(s *Subscriber) bool {
subscribers = append(subscribers, s)

return true
})

return t.lastEventID, subscribers, nil
}
Expand Down Expand Up @@ -257,7 +257,7 @@ func (t *BoltTransport) dispatchHistory(s *Subscriber, toSeq uint64) {
return fmt.Errorf("unable to unmarshal update: %w", err)
}

if !s.Dispatch(update, true) || (toSeq > 0 && binary.BigEndian.Uint64(k[:8]) >= toSeq) {
if (s.Match(update) && !s.Dispatch(update, true)) || (toSeq > 0 && binary.BigEndian.Uint64(k[:8]) >= toSeq) {
s.HistoryDispatched(responseLastEventID)

return nil
Expand All @@ -275,12 +275,13 @@ func (t *BoltTransport) Close() (err error) {
close(t.closed)

t.Lock()
for subscriber := range t.subscribers {
subscriber.Disconnect()
delete(t.subscribers, subscriber)
}
t.Unlock()
defer t.Unlock()

t.subscribers.Walk(0, func(s *Subscriber) bool {
s.Disconnect()

return true
})
err = t.db.Close()
})

Expand Down
53 changes: 43 additions & 10 deletions bolt_transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,25 @@ func TestBoltTransportHistory(t *testing.T) {
}
}

func TestBoltTopicSelectorHistory(t *testing.T) {
transport := createBoltTransport("bolt://test.db")
defer transport.Close()
defer os.Remove("test.db")

transport.Dispatch(&Update{Topics: []string{"http://example.com/subscribed"}, Event: Event{ID: "1"}})
transport.Dispatch(&Update{Topics: []string{"http://example.com/not-subscribed"}, Event: Event{ID: "2"}})
transport.Dispatch(&Update{Topics: []string{"http://example.com/subscribed-public-only"}, Private: true, Event: Event{ID: "3"}})
transport.Dispatch(&Update{Topics: []string{"http://example.com/subscribed-public-only"}, Event: Event{ID: "4"}})

s := NewSubscriber(EarliestLastEventID, transport.logger)
s.SetTopics([]string{"http://example.com/subscribed", "http://example.com/subscribed-public-only"}, []string{"http://example.com/subscribed"})

require.Nil(t, transport.AddSubscriber(s))

assert.Equal(t, "1", (<-s.Receive()).ID)
assert.Equal(t, "4", (<-s.Receive()).ID)
}

func TestBoltTransportRetrieveAllHistory(t *testing.T) {
transport := createBoltTransport("bolt://test.db")
defer transport.Close()
Expand Down Expand Up @@ -202,13 +221,25 @@ func TestBoltTransportDispatch(t *testing.T) {
assert.Implements(t, (*Transport)(nil), transport)

s := NewSubscriber("", transport.logger)
s.SetTopics([]string{"https://example.com/foo"}, nil)
s.SetTopics([]string{"https://example.com/foo", "https://example.com/private"}, []string{"https://example.com/private"})

require.Nil(t, transport.AddSubscriber(s))

u := &Update{Topics: s.Topics}
require.Nil(t, transport.Dispatch(u))
assert.Equal(t, u, <-s.Receive())
notSubscribed := &Update{Topics: []string{"not-subscribed"}}
require.Nil(t, transport.Dispatch(notSubscribed))

subscribedNotAuthorized := &Update{Topics: []string{"https://example.com/foo"}, Private: true}
require.Nil(t, transport.Dispatch(subscribedNotAuthorized))

public := &Update{Topics: s.Topics}
require.Nil(t, transport.Dispatch(public))

assert.Equal(t, public, <-s.Receive())

private := &Update{Topics: s.PrivateTopics, Private: true}
require.Nil(t, transport.Dispatch(private))

assert.Equal(t, private, <-s.Receive())
}

func TestBoltTransportClosed(t *testing.T) {
Expand Down Expand Up @@ -238,24 +269,26 @@ func TestBoltCleanDisconnectedSubscribers(t *testing.T) {
defer os.Remove("test.db")

s1 := NewSubscriber("", transport.logger)
s1.SetTopics([]string{"foo"}, []string{})
require.Nil(t, transport.AddSubscriber(s1))

s2 := NewSubscriber("", transport.logger)
s2.SetTopics([]string{"foo"}, []string{})
require.Nil(t, transport.AddSubscriber(s2))

assert.Len(t, transport.subscribers, 2)
assert.Equal(t, 2, transport.subscribers.Len())

s1.Disconnect()
assert.Len(t, transport.subscribers, 2)
assert.Equal(t, 2, transport.subscribers.Len())

transport.Dispatch(&Update{Topics: s1.Topics})
assert.Len(t, transport.subscribers, 1)
assert.Equal(t, 1, transport.subscribers.Len())

s2.Disconnect()
assert.Len(t, transport.subscribers, 1)
assert.Equal(t, 1, transport.subscribers.Len())

transport.Dispatch(&Update{})
assert.Len(t, transport.subscribers, 0)
transport.Dispatch(&Update{Topics: s1.Topics})
assert.Zero(t, transport.subscribers.Len())
}

func TestBoltGetSubscribers(t *testing.T) {
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ func TestClientClosesThenReconnects(t *testing.T) {
publish := func(data string, waitForSubscribers int) {
for {
transport.Lock()
l := len(transport.subscribers)
l := transport.subscribers.Len()
transport.Unlock()
if l >= waitForSubscribers {
break
Expand Down
4 changes: 2 additions & 2 deletions subscribe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ func TestUnknownLastEventID(t *testing.T) {

for {
transport.RLock()
done := len(transport.subscribers) == 2
done := transport.subscribers.Len() == 2
transport.RUnlock()

if done {
Expand Down Expand Up @@ -709,7 +709,7 @@ func TestUnknownLastEventIDEmptyHistory(t *testing.T) {

for {
transport.RLock()
done := len(transport.subscribers) == 2
done := transport.subscribers.Len() == 2
transport.RUnlock()

if done {
Expand Down
19 changes: 16 additions & 3 deletions subscriber.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ func NewSubscriber(lastEventID string, logger Logger) *Subscriber {
}

// Dispatch an update to the subscriber.
// Security checks must (topics matching) be done before calling Dispatch,
// for instance by calling Match.
func (s *Subscriber) Dispatch(u *Update, fromHistory bool) bool {
if atomic.LoadInt32(&s.disconnected) > 0 {
return false
Expand Down Expand Up @@ -150,8 +152,8 @@ func escapeTopics(topics []string) []string {
return escapedTopics
}

// Match checks if the current subscriber can access to the given topic.
func (s *Subscriber) Match(topic string, private bool) (match bool) {
// MatchTopic checks if the current subscriber can access to the given topic.
func (s *Subscriber) MatchTopic(topic string, private bool) (match bool) {
for i, ts := range s.Topics {
if ts == "*" || ts == topic {
match = true
Expand Down Expand Up @@ -188,11 +190,22 @@ func (s *Subscriber) Match(topic string, private bool) (match bool) {
return false
}

// Match checks if the current subscriber can receive the given update.
func (s *Subscriber) Match(u *Update) bool {
for _, t := range u.Topics {
if s.MatchTopic(t, u.Private) {
return true
}
}

return false
}

// getSubscriptions return the list of subscriptions associated to this subscriber.
func (s *Subscriber) getSubscriptions(topic, context string, active bool) []subscription {
var subscriptions []subscription //nolint:prealloc
for k, t := range s.Topics {
if topic != "" && !s.Match(topic, false) {
if topic != "" && !s.MatchTopic(topic, false) {
continue
}

Expand Down
2 changes: 1 addition & 1 deletion subscriber_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func NewSubscriberList(size int) *SubscriberList {
return false
}

return s.(*Subscriber).Match(p[1], p[0] == "p")
return s.(*Subscriber).MatchTopic(p[1], p[0] == "p")
}, size),
}
}
Expand Down
38 changes: 38 additions & 0 deletions topic_selector_lru_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package mercure

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestMatchLRU(t *testing.T) {
tss, err := NewTopicSelectorStoreLRU(DefaultTopicSelectorStoreLRUMaxEntriesPerShard, DefaultTopicSelectorStoreLRUMaxEntriesPerShard)
require.Nil(t, err)

assert.False(t, tss.match("foo", "bar"))

assert.True(t, tss.match("https://example.com/foo/bar", "https://example.com/{foo}/bar"))

_, found := tss.cache.Get("t_https://example.com/{foo}/bar")
assert.True(t, found)

_, found = tss.cache.Get("m_https://example.com/{foo}/bar_https://example.com/foo/bar")
assert.True(t, found)

assert.True(t, tss.match("https://example.com/foo/bar", "https://example.com/{foo}/bar"))
assert.False(t, tss.match("https://example.com/foo/bar/baz", "https://example.com/{foo}/bar"))

_, found = tss.cache.Get("t_https://example.com/{foo}/bar")
assert.True(t, found)

_, found = tss.cache.Get("m_https://example.com/{foo}/bar_https://example.com/foo/bar")
assert.True(t, found)

assert.True(t, tss.match("https://example.com/kevin/dunglas", "https://example.com/{fistname}/{lastname}"))
assert.True(t, tss.match("https://example.com/foo/bar", "*"))
assert.True(t, tss.match("https://example.com/foo/bar", "https://example.com/foo/bar"))
assert.True(t, tss.match("foo", "foo"))
assert.False(t, tss.match("foo", "bar"))
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"github.com/stretchr/testify/assert"
)

func TestMatch(t *testing.T) {
func TestMatchRistretto(t *testing.T) {
cache, _ := ristretto.NewCache(&ristretto.Config{
NumCounters: TopicSelectorStoreRistrettoDefaultCacheNumCounters,
MaxCost: TopicSelectorStoreRistrettoCacheMaxCost,
Expand Down

0 comments on commit b2adf74

Please sign in to comment.