From c11b406af6e20038f2f8f9e850e7c41f6b68d08c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Dunglas?= Date: Mon, 18 Jan 2021 21:34:02 +0100 Subject: [PATCH] refactor: TransportSubscribers.GetSubscribers() can return an error --- bolt_transport.go | 6 +++--- bolt_transport_test.go | 5 +++-- local_transport.go | 6 +++--- subscribe_test.go | 6 +++--- subscription.go | 9 ++++++++- subscription_test.go | 6 +++--- transport.go | 2 +- transport_test.go | 3 ++- 8 files changed, 26 insertions(+), 17 deletions(-) diff --git a/bolt_transport.go b/bolt_transport.go index a4571c40..74dd4487 100644 --- a/bolt_transport.go +++ b/bolt_transport.go @@ -191,10 +191,10 @@ func (t *BoltTransport) AddSubscriber(s *Subscriber) error { } // GetSubscribers get the list of active subscribers. -func (t *BoltTransport) GetSubscribers() (lastEventID string, subscribers []*Subscriber) { +func (t *BoltTransport) GetSubscribers() (string, []*Subscriber, error) { t.RLock() defer t.RUnlock() - subscribers = make([]*Subscriber, len(t.subscribers)) + subscribers := make([]*Subscriber, len(t.subscribers)) i := 0 for subscriber := range t.subscribers { @@ -202,7 +202,7 @@ func (t *BoltTransport) GetSubscribers() (lastEventID string, subscribers []*Sub i++ } - return t.lastEventID, subscribers + return t.lastEventID, subscribers, nil } func (t *BoltTransport) dispatchHistory(s *Subscriber, toSeq uint64) { diff --git a/bolt_transport_test.go b/bolt_transport_test.go index 04021556..a45948cd 100644 --- a/bolt_transport_test.go +++ b/bolt_transport_test.go @@ -289,11 +289,12 @@ func TestBoltGetSubscribers(t *testing.T) { go s2.start() require.Nil(t, transport.AddSubscriber(s2)) - lastEventID, subscribers := transport.GetSubscribers() + lastEventID, subscribers, err := transport.GetSubscribers() assert.Equal(t, EarliestLastEventID, lastEventID) assert.Len(t, subscribers, 2) assert.Contains(t, subscribers, s1) assert.Contains(t, subscribers, s2) + assert.Nil(t, err) } func TestBoltLastEventID(t *testing.T) { @@ -325,6 +326,6 @@ func TestBoltLastEventID(t *testing.T) { require.NotNil(t, transport) defer transport.Close() - lastEventID, _ := transport.GetSubscribers() + lastEventID, _, _ := transport.GetSubscribers() assert.Equal(t, "foo", lastEventID) } diff --git a/local_transport.go b/local_transport.go index c2b4699e..cddb561f 100644 --- a/local_transport.go +++ b/local_transport.go @@ -68,10 +68,10 @@ func (t *LocalTransport) AddSubscriber(s *Subscriber) error { } // GetSubscribers get the list of active subscribers. -func (t *LocalTransport) GetSubscribers() (lastEventID string, subscribers []*Subscriber) { +func (t *LocalTransport) GetSubscribers() (string, []*Subscriber, error) { t.RLock() defer t.RUnlock() - subscribers = make([]*Subscriber, len(t.subscribers)) + subscribers := make([]*Subscriber, len(t.subscribers)) i := 0 for subscriber := range t.subscribers { @@ -79,7 +79,7 @@ func (t *LocalTransport) GetSubscribers() (lastEventID string, subscribers []*Su i++ } - return t.lastEventID, subscribers + return t.lastEventID, subscribers, nil } // Close closes the Transport. diff --git a/subscribe_test.go b/subscribe_test.go index 55d78b09..59fe9115 100644 --- a/subscribe_test.go +++ b/subscribe_test.go @@ -167,8 +167,8 @@ func (*addSubscriberErrorTransport) AddSubscriber(*Subscriber) error { return errFailedToAddSubscriber } -func (*addSubscriberErrorTransport) GetSubscribers() (string, []*Subscriber) { - return "", []*Subscriber{} +func (*addSubscriberErrorTransport) GetSubscribers() (string, []*Subscriber, error) { + return "", []*Subscriber{}, nil } func (*addSubscriberErrorTransport) Close() error { @@ -390,7 +390,7 @@ func TestSubscriptionEvents(t *testing.T) { defer wg.Done() for { - _, s := hub.transport.(TransportSubscribers).GetSubscribers() + _, s, _ := hub.transport.(TransportSubscribers).GetSubscribers() if len(s) == 2 { break } diff --git a/subscription.go b/subscription.go index 430b3bd8..973f917a 100644 --- a/subscription.go +++ b/subscription.go @@ -118,7 +118,14 @@ func (h *Hub) initSubscription(currentURL string, w http.ResponseWriter, r *http panic("The transport isn't an instance of hub.TransportSubscribers") } - lastEventID, subscribers = transport.GetSubscribers() + var err error + lastEventID, subscribers, err = transport.GetSubscribers() + if err != nil { + h.logger.Error("Error retrieving subscribers", zap.Error(err)) + w.WriteHeader(http.StatusInternalServerError) + + return + } if r.Header.Get("If-None-Match") == lastEventID { w.WriteHeader(http.StatusNotModified) diff --git a/subscription_test.go b/subscription_test.go index 4dcbe3fd..f567b1d3 100644 --- a/subscription_test.go +++ b/subscription_test.go @@ -111,7 +111,7 @@ func TestSubscriptionsHandler(t *testing.T) { assert.Equal(t, subscriptionsURL, subscriptions.ID) assert.Equal(t, "Subscriptions", subscriptions.Type) - lastEventID, subscribers := hub.transport.(TransportSubscribers).GetSubscribers() + lastEventID, subscribers, _ := hub.transport.(TransportSubscribers).GetSubscribers() assert.Equal(t, lastEventID, subscriptions.LastEventID) require.NotEmpty(t, subscribers) @@ -161,7 +161,7 @@ func TestSubscriptionsHandlerForTopic(t *testing.T) { assert.Equal(t, defaultHubURL+"/subscriptions/"+escapedBarTopic, subscriptions.ID) assert.Equal(t, "Subscriptions", subscriptions.Type) - lastEventID, subscribers := hub.transport.(TransportSubscribers).GetSubscribers() + lastEventID, subscribers, _ := hub.transport.(TransportSubscribers).GetSubscribers() assert.Equal(t, lastEventID, subscriptions.LastEventID) require.NotEmpty(t, subscribers) @@ -204,7 +204,7 @@ func TestSubscriptionHandler(t *testing.T) { var subscription subscription json.Unmarshal(w.Body.Bytes(), &subscription) expectedSub := s.getSubscriptions(s.Topics[1], "https://mercure.rocks/", true)[1] - expectedSub.LastEventID, _ = hub.transport.(TransportSubscribers).GetSubscribers() + expectedSub.LastEventID, _, _ = hub.transport.(TransportSubscribers).GetSubscribers() assert.Equal(t, expectedSub, subscription) req = httptest.NewRequest("GET", defaultHubURL+"/subscriptions/notexist/"+s.EscapedID, nil) diff --git a/transport.go b/transport.go index 9b52c7b9..825fa64f 100644 --- a/transport.go +++ b/transport.go @@ -51,7 +51,7 @@ type Transport interface { // TransportSubscribers provide a method to retrieve the list of active subscribers. type TransportSubscribers interface { // GetSubscribers gets the last event ID and the list of active subscribers at this time. - GetSubscribers() (string, []*Subscriber) + GetSubscribers() (string, []*Subscriber, error) } // ErrClosedTransport is returned by the Transport's Dispatch and AddSubscriber methods after a call to Close. diff --git a/transport_test.go b/transport_test.go index ceb36ebf..b40375a4 100644 --- a/transport_test.go +++ b/transport_test.go @@ -138,9 +138,10 @@ func TestLocalTransportGetSubscribers(t *testing.T) { go s2.start() require.Nil(t, transport.AddSubscriber(s2)) - lastEventID, subscribers := transport.(TransportSubscribers).GetSubscribers() + lastEventID, subscribers, err := transport.(TransportSubscribers).GetSubscribers() assert.Equal(t, EarliestLastEventID, lastEventID) assert.Len(t, subscribers, 2) assert.Contains(t, subscribers, s1) assert.Contains(t, subscribers, s2) + assert.Nil(t, err) }