From b2adf74a014004d9e79dd62b754ba5ce59ddc45a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Dunglas?= Date: Sun, 26 Dec 2021 15:20:40 +0100 Subject: [PATCH] fix: add missing topic selector checks in BoltTransport (#597) * fix: re-add missing topic selector checks * fix: use SubscriberList in the Bolt transport * fix: bolt history * fix linter --- bolt_transport.go | 41 +++++++------- bolt_transport_test.go | 53 +++++++++++++++---- transport_test.go => local_transport_test.go | 0 server_test.go | 2 +- subscribe_test.go | 4 +- subscriber.go | 19 +++++-- subscriber_list.go | 2 +- topic_selector_lru_test.go | 38 +++++++++++++ ...est.go => topic_selector_ristretto_test.go | 2 +- 9 files changed, 123 insertions(+), 38 deletions(-) rename transport_test.go => local_transport_test.go (100%) create mode 100644 topic_selector_lru_test.go rename topic_selector_test.go => topic_selector_ristretto_test.go (97%) diff --git a/bolt_transport.go b/bolt_transport.go index 81472f7f..606e01de 100644 --- a/bolt_transport.go +++ b/bolt_transport.go @@ -26,12 +26,12 @@ const defaultBoltBucketName = "updates" // BoltTransport implements the TransportInterface using the Bolt database. type BoltTransport struct { sync.RWMutex + subscribers *SubscriberList logger Logger db *bolt.DB bucketName string size uint64 cleanupFrequency float64 - subscribers map[*Subscriber]struct{} closed chan struct{} closedOnce sync.Once lastSeq uint64 @@ -83,7 +83,7 @@ func NewBoltTransport(u *url.URL, l Logger, tss *TopicSelectorStore) (Transport, bucketName: bucketName, size: size, cleanupFrequency: cleanupFrequency, - subscribers: make(map[*Subscriber]struct{}), + subscribers: NewSubscriberList(1e5), closed: make(chan struct{}), lastEventID: getDBLastEventID(db, bucketName), }, nil @@ -129,9 +129,9 @@ func (t *BoltTransport) Dispatch(update *Update) error { return err } - for subscriber := range t.subscribers { - if !subscriber.Dispatch(update, false) { - delete(t.subscribers, subscriber) + for _, s := range t.subscribers.MatchAny(update) { + if !s.Dispatch(update, false) { + t.subscribers.Remove(s) } } @@ -182,7 +182,7 @@ func (t *BoltTransport) AddSubscriber(s *Subscriber) error { } t.Lock() - t.subscribers[s] = struct{}{} + t.subscribers.Add(s) toSeq := t.lastSeq //nolint:ifshort t.Unlock() @@ -204,8 +204,8 @@ func (t *BoltTransport) RemoveSubscriber(s *Subscriber) error { } t.Lock() - delete(t.subscribers, s) - t.Unlock() + defer t.Unlock() + t.subscribers.Remove(s) return nil } @@ -214,13 +214,13 @@ func (t *BoltTransport) RemoveSubscriber(s *Subscriber) error { func (t *BoltTransport) GetSubscribers() (string, []*Subscriber, error) { t.RLock() defer t.RUnlock() - subscribers := make([]*Subscriber, len(t.subscribers)) - i := 0 - for subscriber := range t.subscribers { - subscribers[i] = subscriber - i++ - } + var subscribers []*Subscriber + t.subscribers.Walk(0, func(s *Subscriber) bool { + subscribers = append(subscribers, s) + + return true + }) return t.lastEventID, subscribers, nil } @@ -257,7 +257,7 @@ func (t *BoltTransport) dispatchHistory(s *Subscriber, toSeq uint64) { return fmt.Errorf("unable to unmarshal update: %w", err) } - if !s.Dispatch(update, true) || (toSeq > 0 && binary.BigEndian.Uint64(k[:8]) >= toSeq) { + if (s.Match(update) && !s.Dispatch(update, true)) || (toSeq > 0 && binary.BigEndian.Uint64(k[:8]) >= toSeq) { s.HistoryDispatched(responseLastEventID) return nil @@ -275,12 +275,13 @@ func (t *BoltTransport) Close() (err error) { close(t.closed) t.Lock() - for subscriber := range t.subscribers { - subscriber.Disconnect() - delete(t.subscribers, subscriber) - } - t.Unlock() + defer t.Unlock() + + t.subscribers.Walk(0, func(s *Subscriber) bool { + s.Disconnect() + return true + }) err = t.db.Close() }) diff --git a/bolt_transport_test.go b/bolt_transport_test.go index 7cca0b05..84cab032 100644 --- a/bolt_transport_test.go +++ b/bolt_transport_test.go @@ -52,6 +52,25 @@ func TestBoltTransportHistory(t *testing.T) { } } +func TestBoltTopicSelectorHistory(t *testing.T) { + transport := createBoltTransport("bolt://test.db") + defer transport.Close() + defer os.Remove("test.db") + + transport.Dispatch(&Update{Topics: []string{"http://example.com/subscribed"}, Event: Event{ID: "1"}}) + transport.Dispatch(&Update{Topics: []string{"http://example.com/not-subscribed"}, Event: Event{ID: "2"}}) + transport.Dispatch(&Update{Topics: []string{"http://example.com/subscribed-public-only"}, Private: true, Event: Event{ID: "3"}}) + transport.Dispatch(&Update{Topics: []string{"http://example.com/subscribed-public-only"}, Event: Event{ID: "4"}}) + + s := NewSubscriber(EarliestLastEventID, transport.logger) + s.SetTopics([]string{"http://example.com/subscribed", "http://example.com/subscribed-public-only"}, []string{"http://example.com/subscribed"}) + + require.Nil(t, transport.AddSubscriber(s)) + + assert.Equal(t, "1", (<-s.Receive()).ID) + assert.Equal(t, "4", (<-s.Receive()).ID) +} + func TestBoltTransportRetrieveAllHistory(t *testing.T) { transport := createBoltTransport("bolt://test.db") defer transport.Close() @@ -202,13 +221,25 @@ func TestBoltTransportDispatch(t *testing.T) { assert.Implements(t, (*Transport)(nil), transport) s := NewSubscriber("", transport.logger) - s.SetTopics([]string{"https://example.com/foo"}, nil) + s.SetTopics([]string{"https://example.com/foo", "https://example.com/private"}, []string{"https://example.com/private"}) require.Nil(t, transport.AddSubscriber(s)) - u := &Update{Topics: s.Topics} - require.Nil(t, transport.Dispatch(u)) - assert.Equal(t, u, <-s.Receive()) + notSubscribed := &Update{Topics: []string{"not-subscribed"}} + require.Nil(t, transport.Dispatch(notSubscribed)) + + subscribedNotAuthorized := &Update{Topics: []string{"https://example.com/foo"}, Private: true} + require.Nil(t, transport.Dispatch(subscribedNotAuthorized)) + + public := &Update{Topics: s.Topics} + require.Nil(t, transport.Dispatch(public)) + + assert.Equal(t, public, <-s.Receive()) + + private := &Update{Topics: s.PrivateTopics, Private: true} + require.Nil(t, transport.Dispatch(private)) + + assert.Equal(t, private, <-s.Receive()) } func TestBoltTransportClosed(t *testing.T) { @@ -238,24 +269,26 @@ func TestBoltCleanDisconnectedSubscribers(t *testing.T) { defer os.Remove("test.db") s1 := NewSubscriber("", transport.logger) + s1.SetTopics([]string{"foo"}, []string{}) require.Nil(t, transport.AddSubscriber(s1)) s2 := NewSubscriber("", transport.logger) + s2.SetTopics([]string{"foo"}, []string{}) require.Nil(t, transport.AddSubscriber(s2)) - assert.Len(t, transport.subscribers, 2) + assert.Equal(t, 2, transport.subscribers.Len()) s1.Disconnect() - assert.Len(t, transport.subscribers, 2) + assert.Equal(t, 2, transport.subscribers.Len()) transport.Dispatch(&Update{Topics: s1.Topics}) - assert.Len(t, transport.subscribers, 1) + assert.Equal(t, 1, transport.subscribers.Len()) s2.Disconnect() - assert.Len(t, transport.subscribers, 1) + assert.Equal(t, 1, transport.subscribers.Len()) - transport.Dispatch(&Update{}) - assert.Len(t, transport.subscribers, 0) + transport.Dispatch(&Update{Topics: s1.Topics}) + assert.Zero(t, transport.subscribers.Len()) } func TestBoltGetSubscribers(t *testing.T) { diff --git a/transport_test.go b/local_transport_test.go similarity index 100% rename from transport_test.go rename to local_transport_test.go diff --git a/server_test.go b/server_test.go index a04d877d..2879efbe 100644 --- a/server_test.go +++ b/server_test.go @@ -268,7 +268,7 @@ func TestClientClosesThenReconnects(t *testing.T) { publish := func(data string, waitForSubscribers int) { for { transport.Lock() - l := len(transport.subscribers) + l := transport.subscribers.Len() transport.Unlock() if l >= waitForSubscribers { break diff --git a/subscribe_test.go b/subscribe_test.go index 878fcd4f..9a2acff0 100644 --- a/subscribe_test.go +++ b/subscribe_test.go @@ -641,7 +641,7 @@ func TestUnknownLastEventID(t *testing.T) { for { transport.RLock() - done := len(transport.subscribers) == 2 + done := transport.subscribers.Len() == 2 transport.RUnlock() if done { @@ -709,7 +709,7 @@ func TestUnknownLastEventIDEmptyHistory(t *testing.T) { for { transport.RLock() - done := len(transport.subscribers) == 2 + done := transport.subscribers.Len() == 2 transport.RUnlock() if done { diff --git a/subscriber.go b/subscriber.go index 97d7d918..0b0a34f1 100644 --- a/subscriber.go +++ b/subscriber.go @@ -52,6 +52,8 @@ func NewSubscriber(lastEventID string, logger Logger) *Subscriber { } // Dispatch an update to the subscriber. +// Security checks must (topics matching) be done before calling Dispatch, +// for instance by calling Match. func (s *Subscriber) Dispatch(u *Update, fromHistory bool) bool { if atomic.LoadInt32(&s.disconnected) > 0 { return false @@ -150,8 +152,8 @@ func escapeTopics(topics []string) []string { return escapedTopics } -// Match checks if the current subscriber can access to the given topic. -func (s *Subscriber) Match(topic string, private bool) (match bool) { +// MatchTopic checks if the current subscriber can access to the given topic. +func (s *Subscriber) MatchTopic(topic string, private bool) (match bool) { for i, ts := range s.Topics { if ts == "*" || ts == topic { match = true @@ -188,11 +190,22 @@ func (s *Subscriber) Match(topic string, private bool) (match bool) { return false } +// Match checks if the current subscriber can receive the given update. +func (s *Subscriber) Match(u *Update) bool { + for _, t := range u.Topics { + if s.MatchTopic(t, u.Private) { + return true + } + } + + return false +} + // getSubscriptions return the list of subscriptions associated to this subscriber. func (s *Subscriber) getSubscriptions(topic, context string, active bool) []subscription { var subscriptions []subscription //nolint:prealloc for k, t := range s.Topics { - if topic != "" && !s.Match(topic, false) { + if topic != "" && !s.MatchTopic(topic, false) { continue } diff --git a/subscriber_list.go b/subscriber_list.go index 0ffb3dd8..a24884d8 100644 --- a/subscriber_list.go +++ b/subscriber_list.go @@ -18,7 +18,7 @@ func NewSubscriberList(size int) *SubscriberList { return false } - return s.(*Subscriber).Match(p[1], p[0] == "p") + return s.(*Subscriber).MatchTopic(p[1], p[0] == "p") }, size), } } diff --git a/topic_selector_lru_test.go b/topic_selector_lru_test.go new file mode 100644 index 00000000..f2d8d52d --- /dev/null +++ b/topic_selector_lru_test.go @@ -0,0 +1,38 @@ +package mercure + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMatchLRU(t *testing.T) { + tss, err := NewTopicSelectorStoreLRU(DefaultTopicSelectorStoreLRUMaxEntriesPerShard, DefaultTopicSelectorStoreLRUMaxEntriesPerShard) + require.Nil(t, err) + + assert.False(t, tss.match("foo", "bar")) + + assert.True(t, tss.match("https://example.com/foo/bar", "https://example.com/{foo}/bar")) + + _, found := tss.cache.Get("t_https://example.com/{foo}/bar") + assert.True(t, found) + + _, found = tss.cache.Get("m_https://example.com/{foo}/bar_https://example.com/foo/bar") + assert.True(t, found) + + assert.True(t, tss.match("https://example.com/foo/bar", "https://example.com/{foo}/bar")) + assert.False(t, tss.match("https://example.com/foo/bar/baz", "https://example.com/{foo}/bar")) + + _, found = tss.cache.Get("t_https://example.com/{foo}/bar") + assert.True(t, found) + + _, found = tss.cache.Get("m_https://example.com/{foo}/bar_https://example.com/foo/bar") + assert.True(t, found) + + assert.True(t, tss.match("https://example.com/kevin/dunglas", "https://example.com/{fistname}/{lastname}")) + assert.True(t, tss.match("https://example.com/foo/bar", "*")) + assert.True(t, tss.match("https://example.com/foo/bar", "https://example.com/foo/bar")) + assert.True(t, tss.match("foo", "foo")) + assert.False(t, tss.match("foo", "bar")) +} diff --git a/topic_selector_test.go b/topic_selector_ristretto_test.go similarity index 97% rename from topic_selector_test.go rename to topic_selector_ristretto_test.go index fb89730f..6275fbea 100644 --- a/topic_selector_test.go +++ b/topic_selector_ristretto_test.go @@ -8,7 +8,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestMatch(t *testing.T) { +func TestMatchRistretto(t *testing.T) { cache, _ := ristretto.NewCache(&ristretto.Config{ NumCounters: TopicSelectorStoreRistrettoDefaultCacheNumCounters, MaxCost: TopicSelectorStoreRistrettoCacheMaxCost,