Skip to content

Commit

Permalink
refactor: TransportSubscribers.GetSubscribers() can return an error
Browse files Browse the repository at this point in the history
  • Loading branch information
dunglas committed Jan 18, 2021
1 parent 309e28f commit c11b406
Show file tree
Hide file tree
Showing 8 changed files with 26 additions and 17 deletions.
6 changes: 3 additions & 3 deletions bolt_transport.go
Expand Up @@ -191,18 +191,18 @@ func (t *BoltTransport) AddSubscriber(s *Subscriber) error {
}

// GetSubscribers get the list of active subscribers.
func (t *BoltTransport) GetSubscribers() (lastEventID string, subscribers []*Subscriber) {
func (t *BoltTransport) GetSubscribers() (string, []*Subscriber, error) {
t.RLock()
defer t.RUnlock()
subscribers = make([]*Subscriber, len(t.subscribers))
subscribers := make([]*Subscriber, len(t.subscribers))

i := 0
for subscriber := range t.subscribers {
subscribers[i] = subscriber
i++
}

return t.lastEventID, subscribers
return t.lastEventID, subscribers, nil
}

func (t *BoltTransport) dispatchHistory(s *Subscriber, toSeq uint64) {
Expand Down
5 changes: 3 additions & 2 deletions bolt_transport_test.go
Expand Up @@ -289,11 +289,12 @@ func TestBoltGetSubscribers(t *testing.T) {
go s2.start()
require.Nil(t, transport.AddSubscriber(s2))

lastEventID, subscribers := transport.GetSubscribers()
lastEventID, subscribers, err := transport.GetSubscribers()
assert.Equal(t, EarliestLastEventID, lastEventID)
assert.Len(t, subscribers, 2)
assert.Contains(t, subscribers, s1)
assert.Contains(t, subscribers, s2)
assert.Nil(t, err)
}

func TestBoltLastEventID(t *testing.T) {
Expand Down Expand Up @@ -325,6 +326,6 @@ func TestBoltLastEventID(t *testing.T) {
require.NotNil(t, transport)
defer transport.Close()

lastEventID, _ := transport.GetSubscribers()
lastEventID, _, _ := transport.GetSubscribers()
assert.Equal(t, "foo", lastEventID)
}
6 changes: 3 additions & 3 deletions local_transport.go
Expand Up @@ -68,18 +68,18 @@ func (t *LocalTransport) AddSubscriber(s *Subscriber) error {
}

// GetSubscribers get the list of active subscribers.
func (t *LocalTransport) GetSubscribers() (lastEventID string, subscribers []*Subscriber) {
func (t *LocalTransport) GetSubscribers() (string, []*Subscriber, error) {
t.RLock()
defer t.RUnlock()
subscribers = make([]*Subscriber, len(t.subscribers))
subscribers := make([]*Subscriber, len(t.subscribers))

i := 0
for subscriber := range t.subscribers {
subscribers[i] = subscriber
i++
}

return t.lastEventID, subscribers
return t.lastEventID, subscribers, nil
}

// Close closes the Transport.
Expand Down
6 changes: 3 additions & 3 deletions subscribe_test.go
Expand Up @@ -167,8 +167,8 @@ func (*addSubscriberErrorTransport) AddSubscriber(*Subscriber) error {
return errFailedToAddSubscriber
}

func (*addSubscriberErrorTransport) GetSubscribers() (string, []*Subscriber) {
return "", []*Subscriber{}
func (*addSubscriberErrorTransport) GetSubscribers() (string, []*Subscriber, error) {
return "", []*Subscriber{}, nil
}

func (*addSubscriberErrorTransport) Close() error {
Expand Down Expand Up @@ -390,7 +390,7 @@ func TestSubscriptionEvents(t *testing.T) {
defer wg.Done()

for {
_, s := hub.transport.(TransportSubscribers).GetSubscribers()
_, s, _ := hub.transport.(TransportSubscribers).GetSubscribers()
if len(s) == 2 {
break
}
Expand Down
9 changes: 8 additions & 1 deletion subscription.go
Expand Up @@ -118,7 +118,14 @@ func (h *Hub) initSubscription(currentURL string, w http.ResponseWriter, r *http
panic("The transport isn't an instance of hub.TransportSubscribers")
}

lastEventID, subscribers = transport.GetSubscribers()
var err error
lastEventID, subscribers, err = transport.GetSubscribers()
if err != nil {
h.logger.Error("Error retrieving subscribers", zap.Error(err))
w.WriteHeader(http.StatusInternalServerError)

return
}
if r.Header.Get("If-None-Match") == lastEventID {
w.WriteHeader(http.StatusNotModified)

Expand Down
6 changes: 3 additions & 3 deletions subscription_test.go
Expand Up @@ -111,7 +111,7 @@ func TestSubscriptionsHandler(t *testing.T) {
assert.Equal(t, subscriptionsURL, subscriptions.ID)
assert.Equal(t, "Subscriptions", subscriptions.Type)

lastEventID, subscribers := hub.transport.(TransportSubscribers).GetSubscribers()
lastEventID, subscribers, _ := hub.transport.(TransportSubscribers).GetSubscribers()

assert.Equal(t, lastEventID, subscriptions.LastEventID)
require.NotEmpty(t, subscribers)
Expand Down Expand Up @@ -161,7 +161,7 @@ func TestSubscriptionsHandlerForTopic(t *testing.T) {
assert.Equal(t, defaultHubURL+"/subscriptions/"+escapedBarTopic, subscriptions.ID)
assert.Equal(t, "Subscriptions", subscriptions.Type)

lastEventID, subscribers := hub.transport.(TransportSubscribers).GetSubscribers()
lastEventID, subscribers, _ := hub.transport.(TransportSubscribers).GetSubscribers()

assert.Equal(t, lastEventID, subscriptions.LastEventID)
require.NotEmpty(t, subscribers)
Expand Down Expand Up @@ -204,7 +204,7 @@ func TestSubscriptionHandler(t *testing.T) {
var subscription subscription
json.Unmarshal(w.Body.Bytes(), &subscription)
expectedSub := s.getSubscriptions(s.Topics[1], "https://mercure.rocks/", true)[1]
expectedSub.LastEventID, _ = hub.transport.(TransportSubscribers).GetSubscribers()
expectedSub.LastEventID, _, _ = hub.transport.(TransportSubscribers).GetSubscribers()
assert.Equal(t, expectedSub, subscription)

req = httptest.NewRequest("GET", defaultHubURL+"/subscriptions/notexist/"+s.EscapedID, nil)
Expand Down
2 changes: 1 addition & 1 deletion transport.go
Expand Up @@ -51,7 +51,7 @@ type Transport interface {
// TransportSubscribers provide a method to retrieve the list of active subscribers.
type TransportSubscribers interface {
// GetSubscribers gets the last event ID and the list of active subscribers at this time.
GetSubscribers() (string, []*Subscriber)
GetSubscribers() (string, []*Subscriber, error)
}

// ErrClosedTransport is returned by the Transport's Dispatch and AddSubscriber methods after a call to Close.
Expand Down
3 changes: 2 additions & 1 deletion transport_test.go
Expand Up @@ -138,9 +138,10 @@ func TestLocalTransportGetSubscribers(t *testing.T) {
go s2.start()
require.Nil(t, transport.AddSubscriber(s2))

lastEventID, subscribers := transport.(TransportSubscribers).GetSubscribers()
lastEventID, subscribers, err := transport.(TransportSubscribers).GetSubscribers()
assert.Equal(t, EarliestLastEventID, lastEventID)
assert.Len(t, subscribers, 2)
assert.Contains(t, subscribers, s1)
assert.Contains(t, subscribers, s2)
assert.Nil(t, err)
}

0 comments on commit c11b406

Please sign in to comment.