From 036cb10c7313a61fbcfc89f33dd9131c4bafe3a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Dunglas?= Date: Thu, 7 May 2020 12:52:15 +0200 Subject: [PATCH] API cleanup --- hub/bolt_transport.go | 10 +-- hub/bolt_transport_test.go | 48 +++++----- hub/log.go | 4 +- hub/metrics.go | 4 +- hub/metrics_test.go | 16 ++-- hub/publish_test.go | 20 ++--- hub/subscribe.go | 64 +++++++------- hub/subscriber.go | 173 +++++++++++++++++++++---------------- hub/subscriber_test.go | 37 +++++++- hub/transport_test.go | 40 ++++----- 10 files changed, 234 insertions(+), 182 deletions(-) diff --git a/hub/bolt_transport.go b/hub/bolt_transport.go index 6e5ad5d6..6b0e96d3 100644 --- a/hub/bolt_transport.go +++ b/hub/bolt_transport.go @@ -149,21 +149,21 @@ func (t *BoltTransport) AddSubscriber(s *Subscriber) error { t.Lock() t.subscribers[s] = struct{}{} - if s.History.In == nil { + if s.LastEventID == "" { t.Unlock() return nil } t.Unlock() toSeq := t.lastSeq.Load() - t.dispatchFromHistory(s.lastEventID, toSeq, s) + t.dispatchHistory(s, toSeq) return nil } -func (t *BoltTransport) dispatchFromHistory(lastEventID string, toSeq uint64, s *Subscriber) { +func (t *BoltTransport) dispatchHistory(s *Subscriber, toSeq uint64) { t.db.View(func(tx *bolt.Tx) error { - defer close(s.History.In) + defer s.HistoryDispatched() b := tx.Bucket([]byte(t.bucketName)) if b == nil { return nil // No data @@ -173,7 +173,7 @@ func (t *BoltTransport) dispatchFromHistory(lastEventID string, toSeq uint64, s afterFromID := false for k, v := c.First(); k != nil; k, v = c.Next() { if !afterFromID { - if string(k[8:]) == lastEventID { + if string(k[8:]) == s.LastEventID { afterFromID = true } diff --git a/hub/bolt_transport_test.go b/hub/bolt_transport_test.go index 849bd30c..9e27b253 100644 --- a/hub/bolt_transport_test.go +++ b/hub/bolt_transport_test.go @@ -26,11 +26,9 @@ func TestBoltTransportHistory(t *testing.T) { }) } - s := newSubscriber() - s.topics = topics - s.rawTopics = topics - s.lastEventID = "8" - s.History.In = make(chan *Update) + s := newSubscriber("8") + s.Topics = topics + s.RawTopics = topics go s.start() err := transport.AddSubscriber(s) @@ -38,7 +36,7 @@ func TestBoltTransportHistory(t *testing.T) { var count int for { - u := <-s.Out + u := <-s.Receive() // the reading loop must read the #9 and #10 messages assert.Equal(t, strconv.Itoa(9+count), u.ID) count++ @@ -62,11 +60,9 @@ func TestBoltTransportHistoryAndLive(t *testing.T) { }) } - s := newSubscriber() - s.topics = topics - s.rawTopics = topics - s.lastEventID = "8" - s.History.In = make(chan *Update) + s := newSubscriber("8") + s.Topics = topics + s.RawTopics = topics go s.start() err := transport.AddSubscriber(s) @@ -78,7 +74,7 @@ func TestBoltTransportHistoryAndLive(t *testing.T) { defer wg.Done() var count int for { - u := <-s.Out + u := <-s.Receive() // the reading loop must read the #9, #10 and #11 messages assert.Equal(t, strconv.Itoa(9+count), u.ID) @@ -152,7 +148,7 @@ func TestBoltTransportDoNotDispatchedUntilListen(t *testing.T) { defer os.Remove("test.db") assert.Implements(t, (*Transport)(nil), transport) - s := newSubscriber() + s := newSubscriber("") go s.start() err := transport.AddSubscriber(s) @@ -166,7 +162,7 @@ func TestBoltTransportDoNotDispatchedUntilListen(t *testing.T) { wg.Add(1) go func() { select { - case readUpdate = <-s.Out: + case readUpdate = <-s.Receive(): case <-s.disconnected: ok = true } @@ -188,20 +184,20 @@ func TestBoltTransportDispatch(t *testing.T) { defer os.Remove("test.db") assert.Implements(t, (*Transport)(nil), transport) - s := newSubscriber() - s.topics = []string{"https://example.com/foo"} - s.rawTopics = s.topics + s := newSubscriber("") + s.Topics = []string{"https://example.com/foo"} + s.RawTopics = s.Topics go s.start() err := transport.AddSubscriber(s) assert.Nil(t, err) - u := &Update{Topics: s.topics} + u := &Update{Topics: s.Topics} err = transport.Dispatch(u) assert.Nil(t, err) - readUpdate := <-s.Out + readUpdate := <-s.Receive() assert.Equal(t, u, readUpdate) } @@ -213,9 +209,9 @@ func TestBoltTransportClosed(t *testing.T) { defer os.Remove("test.db") assert.Implements(t, (*Transport)(nil), transport) - s := newSubscriber() - s.topics = []string{"https://example.com/foo"} - s.rawTopics = s.topics + s := newSubscriber("") + s.Topics = []string{"https://example.com/foo"} + s.RawTopics = s.Topics go s.start() err := transport.AddSubscriber(s) @@ -227,7 +223,7 @@ func TestBoltTransportClosed(t *testing.T) { err = transport.AddSubscriber(s) assert.Equal(t, err, ErrClosedTransport) - err = transport.Dispatch(&Update{Topics: s.topics}) + err = transport.Dispatch(&Update{Topics: s.Topics}) assert.Equal(t, err, ErrClosedTransport) _, ok := <-s.disconnected @@ -241,12 +237,12 @@ func TestBoltCleanDisconnectedSubscribers(t *testing.T) { defer transport.Close() defer os.Remove("test.db") - s1 := newSubscriber() + s1 := newSubscriber("") go s1.start() err := transport.AddSubscriber(s1) require.Nil(t, err) - s2 := newSubscriber() + s2 := newSubscriber("") go s2.start() err = transport.AddSubscriber(s2) require.Nil(t, err) @@ -256,7 +252,7 @@ func TestBoltCleanDisconnectedSubscribers(t *testing.T) { s1.Disconnect() assert.Len(t, transport.subscribers, 2) - transport.Dispatch(&Update{Topics: s1.topics}) + transport.Dispatch(&Update{Topics: s1.Topics}) assert.Len(t, transport.subscribers, 1) s2.Disconnect() diff --git a/hub/log.go b/hub/log.go index 422f3235..12877749 100644 --- a/hub/log.go +++ b/hub/log.go @@ -21,8 +21,8 @@ func addUpdateFields(f log.Fields, u *Update, debug bool) log.Fields { } func createFields(u *Update, s *Subscriber) log.Fields { - f := addUpdateFields(log.Fields{}, u, s.debug) - for k, v := range s.logFields { + f := addUpdateFields(log.Fields{}, u, s.Debug) + for k, v := range s.LogFields { f[k] = v } diff --git a/hub/metrics.go b/hub/metrics.go index 78821ceb..c81a0932 100644 --- a/hub/metrics.go +++ b/hub/metrics.go @@ -59,7 +59,7 @@ func (m *Metrics) Register(r *mux.Router) { // NewSubscriber collects metrics about new subscriber events. func (m *Metrics) NewSubscriber(s *Subscriber) { - for _, t := range s.topics { + for _, t := range s.Topics { m.subscribersTotal.WithLabelValues(t).Inc() m.subscribers.WithLabelValues(t).Inc() } @@ -67,7 +67,7 @@ func (m *Metrics) NewSubscriber(s *Subscriber) { // SubscriberDisconnect collects metrics about subscriber disconnection events. func (m *Metrics) SubscriberDisconnect(s *Subscriber) { - for _, t := range s.topics { + for _, t := range s.Topics { m.subscribers.WithLabelValues(t).Dec() } } diff --git a/hub/metrics_test.go b/hub/metrics_test.go index af7a2f14..42c6f864 100644 --- a/hub/metrics_test.go +++ b/hub/metrics_test.go @@ -11,14 +11,14 @@ import ( func TestNumberOfRunningSubscribers(t *testing.T) { m := NewMetrics() - s1 := newSubscriber() - s1.topics = []string{"topic1", "topic2"} + s1 := newSubscriber("") + s1.Topics = []string{"topic1", "topic2"} m.NewSubscriber(s1) assertGaugeLabelValue(t, 1.0, m.subscribers, "topic1") assertGaugeLabelValue(t, 1.0, m.subscribers, "topic2") - s2 := newSubscriber() - s2.topics = []string{"topic2"} + s2 := newSubscriber("") + s2.Topics = []string{"topic2"} m.NewSubscriber(s2) assertGaugeLabelValue(t, 1.0, m.subscribers, "topic1") assertGaugeLabelValue(t, 2.0, m.subscribers, "topic2") @@ -35,14 +35,14 @@ func TestNumberOfRunningSubscribers(t *testing.T) { func TestTotalNumberOfHandledSubscribers(t *testing.T) { m := NewMetrics() - s1 := newSubscriber() - s1.topics = []string{"topic1", "topic2"} + s1 := newSubscriber("") + s1.Topics = []string{"topic1", "topic2"} m.NewSubscriber(s1) assertCounterValue(t, 1.0, m.subscribersTotal, "topic1") assertCounterValue(t, 1.0, m.subscribersTotal, "topic2") - s2 := newSubscriber() - s2.topics = []string{"topic2"} + s2 := newSubscriber("") + s2.Topics = []string{"topic2"} m.NewSubscriber(s2) assertCounterValue(t, 1.0, m.subscribersTotal, "topic1") assertCounterValue(t, 2.0, m.subscribersTotal, "topic2") diff --git a/hub/publish_test.go b/hub/publish_test.go index bf671374..6ef5a9c4 100644 --- a/hub/publish_test.go +++ b/hub/publish_test.go @@ -155,10 +155,10 @@ func TestPublishOK(t *testing.T) { hub := createDummy() defer hub.Stop() - s := newSubscriber() - s.topics = []string{"http://example.com/books/1"} - s.rawTopics = s.topics - s.targets = map[string]struct{}{"foo": {}} + s := newSubscriber("") + s.Topics = []string{"http://example.com/books/1"} + s.RawTopics = s.Topics + s.Targets = map[string]struct{}{"foo": {}} go s.start() err := hub.transport.AddSubscriber(s) @@ -168,7 +168,7 @@ func TestPublishOK(t *testing.T) { wg.Add(1) go func(w *sync.WaitGroup) { defer w.Done() - u, ok := <-s.Out + u, ok := <-s.Receive() assert.True(t, ok) require.NotNil(t, u) assert.Equal(t, "id", u.ID) @@ -206,10 +206,10 @@ func TestPublishGenerateUUID(t *testing.T) { h := createDummy() defer h.Stop() - s := newSubscriber() - s.topics = []string{"http://example.com/books/1"} - s.rawTopics = s.topics - s.targets = map[string]struct{}{"foo": {}} + s := newSubscriber("") + s.Topics = []string{"http://example.com/books/1"} + s.RawTopics = s.Topics + s.Targets = map[string]struct{}{"foo": {}} go s.start() h.transport.AddSubscriber(s) @@ -218,7 +218,7 @@ func TestPublishGenerateUUID(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - u := <-s.Out + u := <-s.Receive() require.NotNil(t, u) _, err := uuid.FromString(u.ID) diff --git a/hub/subscribe.go b/hub/subscribe.go index 7970dd97..cd39b6ae 100644 --- a/hub/subscribe.go +++ b/hub/subscribe.go @@ -49,7 +49,7 @@ func (h *Hub) SubscribeHandler(w http.ResponseWriter, r *http.Request) { for { select { - case <-s.disconnected: + case <-s.Disconnected(): // Server closes the connection return case <-r.Context().Done(): @@ -61,7 +61,7 @@ func (h *Hub) SubscribeHandler(w http.ResponseWriter, r *http.Request) { return } timer.Reset(hearthbeatInterval) - case update := <-s.Out: + case update := <-s.Receive(): if !h.write(w, r, s, newSerializedUpdate(update).event) { return } @@ -78,52 +78,52 @@ func (h *Hub) SubscribeHandler(w http.ResponseWriter, r *http.Request) { // registerSubscriber initializes the connection. func (h *Hub) registerSubscriber(w http.ResponseWriter, r *http.Request, debug bool) *Subscriber { - s := newSubscriber() - s.debug = debug - s.logFields["remote_addr"] = r.RemoteAddr + s := newSubscriber("") + s.Debug = debug + s.LogFields["remote_addr"] = r.RemoteAddr claims, err := authorize(r, h.getJWTKey(subscriberRole), h.getJWTAlgorithm(subscriberRole), nil) if claims != nil { - s.claims = claims - s.logFields["subscriber_targets"] = claims.Mercure.Subscribe + s.Claims = claims + s.LogFields["subscriber_targets"] = claims.Mercure.Subscribe } if err != nil || (claims == nil && !h.config.GetBool("allow_anonymous")) { http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) - log.WithFields(s.logFields).Info(err) + log.WithFields(s.LogFields).Info(err) return nil } - s.topics = r.URL.Query()["topic"] - if len(s.topics) == 0 { + s.Topics = r.URL.Query()["topic"] + if len(s.Topics) == 0 { http.Error(w, "Missing \"topic\" parameter.", http.StatusBadRequest) return nil } - s.logFields["subscriber_topics"] = s.topics + s.LogFields["subscriber_topics"] = s.Topics - s.rawTopics, s.templateTopics = h.parseTopics(s.topics) - s.escapedTopics = escapeTopics(s.topics) - s.allTargets, s.targets = authorizedTargets(claims, false) - s.remoteAddr = r.RemoteAddr + s.RawTopics, s.TemplateTopics = h.parseTopics(s.Topics) + s.EscapedTopics = escapeTopics(s.Topics) + s.AllTargets, s.Targets = authorizedTargets(claims, false) + s.RemoteAddr = r.RemoteAddr - s.lastEventID = retrieveLastEventID(r) - if s.lastEventID != "" { - s.History.In = make(chan *Update) - s.logFields["last_event_id"] = s.lastEventID + s.LastEventID = retrieveLastEventID(r) + if s.LastEventID != "" { + s.history.In = make(chan *Update) + s.LogFields["last_event_id"] = s.LastEventID } go s.start() if h.config.GetBool("subscriptions_include_ip") { - s.remoteHost, _, _ = net.SplitHostPort(r.RemoteAddr) + s.RemoteHost, _, _ = net.SplitHostPort(r.RemoteAddr) } h.dispatchSubscriptionUpdate(s, true) if h.transport.AddSubscriber(s) != nil { http.Error(w, http.StatusText(http.StatusServiceUnavailable), http.StatusServiceUnavailable) h.dispatchSubscriptionUpdate(s, false) - log.WithFields(s.logFields).Error(err) + log.WithFields(s.LogFields).Error(err) return nil } sendHeaders(w) - log.WithFields(s.logFields).Info("New subscriber") + log.WithFields(s.LogFields).Info("New subscriber") h.metrics.NewSubscriber(s) @@ -216,7 +216,7 @@ func (h *Hub) write(w io.Writer, r *http.Request, s *Subscriber, data string) bo case <-done: return true case <-time.After(d): - log.WithFields(s.logFields).Warn("Dispatch timeout reached") + log.WithFields(s.LogFields).Warn("Dispatch timeout reached") return false } } @@ -225,13 +225,13 @@ func (h *Hub) shutdown(s *Subscriber) { // Notify that the client is closing the connection s.Disconnect() h.dispatchSubscriptionUpdate(s, false) - log.WithFields(s.logFields).Info("Subscriber disconnected") + log.WithFields(s.LogFields).Info("Subscriber disconnected") h.metrics.SubscriberDisconnect(s) // Remove unused uritemplate.Template instances from memory. - keys := make([]string, 0, len(s.rawTopics)+len(s.templateTopics)) - copy(s.rawTopics, keys) - for _, uriTemplate := range s.templateTopics { + keys := make([]string, 0, len(s.RawTopics)+len(s.TemplateTopics)) + copy(s.RawTopics, keys) + for _, uriTemplate := range s.TemplateTopics { keys = append(keys, uriTemplate.Raw()) } @@ -252,16 +252,16 @@ func (h *Hub) dispatchSubscriptionUpdate(s *Subscriber, active bool) { return } - for k, topic := range s.topics { + for k, topic := range s.Topics { connection := &subscription{ - ID: "https://mercure.rocks/subscriptions/" + s.escapedTopics[k] + "/" + s.ID, + ID: "https://mercure.rocks/subscriptions/" + s.EscapedTopics[k] + "/" + s.ID, Type: "https://mercure.rocks/Subscription", Topic: topic, Active: active, - Address: s.remoteHost, + Address: s.RemoteHost, } - if s.claims == nil { + if s.Claims == nil { connection.mercureClaim.Publish = []string{} connection.mercureClaim.Subscribe = []string{} } else { @@ -280,7 +280,7 @@ func (h *Hub) dispatchSubscriptionUpdate(s *Subscriber, active bool) { u := &Update{ Topics: []string{connection.ID}, - Targets: map[string]struct{}{"https://mercure.rocks/targets/subscriptions": {}, "https://mercure.rocks/targets/subscriptions/" + s.escapedTopics[k]: {}}, + Targets: map[string]struct{}{"https://mercure.rocks/targets/subscriptions": {}, "https://mercure.rocks/targets/subscriptions/" + s.EscapedTopics[k]: {}}, Event: Event{Data: string(json), ID: uuid.Must(uuid.NewV4()).String()}, } diff --git a/hub/subscriber.go b/hub/subscriber.go index c0af1135..d0396d9a 100644 --- a/hub/subscriber.go +++ b/hub/subscriber.go @@ -13,92 +13,145 @@ type updateSource struct { // Subscriber represents a client subscribed to a list of topics. type Subscriber struct { - ID string - History updateSource - Live updateSource - Out chan *Update - - disconnected chan struct{} - claims *claims - allTargets bool - targets map[string]struct{} - topics []string - escapedTopics []string - rawTopics []string - templateTopics []*uritemplate.Template - lastEventID string - remoteAddr string - remoteHost string - debug bool - - logFields log.Fields - matchCache map[string]bool + ID string + Claims *claims + AllTargets bool + Targets map[string]struct{} + Topics []string + EscapedTopics []string + RawTopics []string + TemplateTopics []*uritemplate.Template + LastEventID string + RemoteAddr string + RemoteHost string + Debug bool + LogFields log.Fields + + history updateSource + live updateSource + out chan *Update + disconnected chan struct{} + matchCache map[string]bool } -func newSubscriber() *Subscriber { +func newSubscriber(lastEventID string) *Subscriber { id := uuid.Must(uuid.NewV4()).String() - return &Subscriber{ + s := &Subscriber{ ID: id, - History: updateSource{}, - Live: updateSource{In: make(chan *Update)}, - Out: make(chan *Update), + LastEventID: lastEventID, + LogFields: log.Fields{"subscriber_id": id}, + history: updateSource{}, + live: updateSource{In: make(chan *Update)}, + out: make(chan *Update), disconnected: make(chan struct{}), - logFields: log.Fields{"subscriber_id": id}, matchCache: make(map[string]bool), } + + if lastEventID != "" { + s.history.In = make(chan *Update) + } + + return s } +// start stores incoming updates in an history and a live buffer and dispatch them. +// updates coming from the history are always dispatched first func (s *Subscriber) start() { for { select { case <-s.disconnected: return - case u, ok := <-s.History.In: + case u, ok := <-s.history.In: if !ok { - s.History.In = nil + s.history.In = nil break } if s.CanDispatch(u) { - s.History.buffer = append(s.History.buffer, u) + s.history.buffer = append(s.history.buffer, u) } - case u := <-s.Live.In: + case u := <-s.live.In: if s.CanDispatch(u) { - s.Live.buffer = append(s.Live.buffer, u) + s.live.buffer = append(s.live.buffer, u) } case s.outChan() <- s.nextUpdate(): - if len(s.History.buffer) > 0 { - s.History.buffer = s.History.buffer[1:] + if len(s.history.buffer) > 0 { + s.history.buffer = s.history.buffer[1:] break } - s.Live.buffer = s.Live.buffer[1:] + s.live.buffer = s.live.buffer[1:] } } } -func (s *Subscriber) outChan() chan *Update { - if len(s.Live.buffer) > 0 || len(s.History.buffer) > 0 { - return s.Out +// outChan returns the out channel if buffers aren't empty, or nil to block +func (s *Subscriber) outChan() chan<- *Update { + if len(s.live.buffer) > 0 || len(s.history.buffer) > 0 { + return s.out } return nil } +// nextUpdate returns the next update to dispatch. +// the history is always entirely flushed before starting to dispatch live updates func (s *Subscriber) nextUpdate() *Update { // Always flush the history buffer first to preserve order - if s.History.In != nil || len(s.History.buffer) > 0 { - if len(s.History.buffer) > 0 { - return s.History.buffer[0] + if s.history.In != nil || len(s.history.buffer) > 0 { + if len(s.history.buffer) > 0 { + return s.history.buffer[0] } return nil } - if len(s.Live.buffer) > 0 { - return s.Live.buffer[0] + if len(s.live.buffer) > 0 { + return s.live.buffer[0] } return nil } +// Dispatch an update to the subscriber. +func (s *Subscriber) Dispatch(u *Update, fromHistory bool) bool { + var in chan<- *Update + if fromHistory { + in = s.history.In + } else { + in = s.live.In + } + + select { + case <-s.disconnected: + return false + case in <- u: + } + + return true +} + +// Receive +func (s *Subscriber) Receive() <-chan *Update { + return s.out +} + +func (s *Subscriber) HistoryDispatched() { + close(s.history.In) +} + +// Disconnect disconnects the subscriber. +func (s *Subscriber) Disconnect() { + select { + case <-s.disconnected: + return + default: + } + + close(s.disconnected) +} + +func (s *Subscriber) Disconnected() <-chan struct{} { + return s.disconnected +} + // CanDispatch checks if an update can be dispatched to this subsriber. func (s *Subscriber) CanDispatch(u *Update) bool { if !s.IsAuthorized(u) { @@ -117,11 +170,11 @@ func (s *Subscriber) CanDispatch(u *Update) bool { // IsAuthorized checks if the subscriber can access to at least one of the update's intended targets. // Don't forget to also call IsSubscribed. func (s *Subscriber) IsAuthorized(u *Update) bool { - if s.allTargets || len(u.Targets) == 0 { + if s.AllTargets || len(u.Targets) == 0 { return true } - for t := range s.targets { + for t := range s.Targets { if _, ok := u.Targets[t]; ok { return true } @@ -141,14 +194,14 @@ func (s *Subscriber) IsSubscribed(u *Update) bool { continue } - for _, rt := range s.rawTopics { + for _, rt := range s.RawTopics { if ut == rt { s.matchCache[ut] = true return true } } - for _, tt := range s.templateTopics { + for _, tt := range s.TemplateTopics { if tt.Match(ut) != nil { s.matchCache[ut] = true return true @@ -160,31 +213,3 @@ func (s *Subscriber) IsSubscribed(u *Update) bool { return false } - -// Dispatch an update to the subscriber. -func (s *Subscriber) Dispatch(u *Update, fromHistory bool) bool { - var in chan<- *Update - if fromHistory { - in = s.History.In - } else { - in = s.Live.In - } - - select { - case <-s.disconnected: - return false - case in <- u: - } - - return true -} - -func (s *Subscriber) Disconnect() { - select { - case <-s.disconnected: - return - default: - } - - close(s.disconnected) -} diff --git a/hub/subscriber_test.go b/hub/subscriber_test.go index e2fdfb4c..ebd93605 100644 --- a/hub/subscriber_test.go +++ b/hub/subscriber_test.go @@ -1,15 +1,16 @@ package hub import ( + "strconv" "testing" "github.com/stretchr/testify/assert" ) func TestIsSubscribed(t *testing.T) { - s := newSubscriber() - s.topics = []string{"foo", "bar"} - s.rawTopics = s.topics + s := newSubscriber("") + s.Topics = []string{"foo", "bar"} + s.RawTopics = s.Topics assert.Len(t, s.matchCache, 0) assert.False(t, s.IsSubscribed(&Update{Topics: []string{"baz", "bat"}})) @@ -20,3 +21,33 @@ func TestIsSubscribed(t *testing.T) { assert.True(t, s.IsSubscribed(&Update{Topics: []string{"bar", "qux"}})) assert.Len(t, s.matchCache, 3) } + +func TestDispatch(t *testing.T) { + s := newSubscriber("1") + s.Topics = []string{"http://example.com"} + s.RawTopics = s.Topics + go s.start() + defer s.Disconnect() + + // Dispatch must be non-blocking + // Messages comming from the history can be sent after live messages, but must be received first + s.Dispatch(&Update{Topics: s.Topics, Event: Event{ID: "3"}}, false) + s.Dispatch(&Update{Topics: s.Topics, Event: Event{ID: "1"}}, true) + s.Dispatch(&Update{Topics: s.Topics, Event: Event{ID: "4"}}, false) + s.Dispatch(&Update{Topics: s.Topics, Event: Event{ID: "2"}}, true) + s.HistoryDispatched() + + for i := 1; i <= 4; i++ { + u := <-s.Receive() + assert.Equal(t, strconv.Itoa(i), u.ID) + } +} + +func TestDisconnect(t *testing.T) { + s := newSubscriber("") + s.Disconnect() + // can be called two times without crashing + s.Disconnect() + + assert.False(t, s.Dispatch(&Update{}, false)) +} diff --git a/hub/transport_test.go b/hub/transport_test.go index 841da188..9d3b0e76 100644 --- a/hub/transport_test.go +++ b/hub/transport_test.go @@ -21,10 +21,10 @@ func TestLocalTransportDoNotDispatchUntilListen(t *testing.T) { err := transport.Dispatch(u) require.Nil(t, err) - s := newSubscriber() - s.topics = u.Topics - s.rawTopics = u.Topics - s.targets = map[string]struct{}{"foo": {}} + s := newSubscriber("") + s.Topics = u.Topics + s.RawTopics = u.Topics + s.Targets = map[string]struct{}{"foo": {}} go s.start() err = transport.AddSubscriber(s) @@ -39,7 +39,7 @@ func TestLocalTransportDoNotDispatchUntilListen(t *testing.T) { go func() { defer wg.Done() select { - case readUpdate = <-s.Out: + case readUpdate = <-s.Receive(): case <-s.disconnected: ok = true } @@ -57,20 +57,20 @@ func TestLocalTransportDispatch(t *testing.T) { defer transport.Close() assert.Implements(t, (*Transport)(nil), transport) - s := newSubscriber() - s.topics = []string{"http://example.com/foo"} - s.rawTopics = s.topics + s := newSubscriber("") + s.Topics = []string{"http://example.com/foo"} + s.RawTopics = s.Topics go s.start() err := transport.AddSubscriber(s) assert.Nil(t, err) - u := &Update{Topics: s.topics} + u := &Update{Topics: s.Topics} err = transport.Dispatch(u) assert.Nil(t, err) - readUpdate := <-s.Out + readUpdate := <-s.Receive() assert.Equal(t, u, readUpdate) } @@ -79,14 +79,14 @@ func TestLocalTransportClosed(t *testing.T) { defer transport.Close() assert.Implements(t, (*Transport)(nil), transport) - s := newSubscriber() + s := newSubscriber("") err := transport.AddSubscriber(s) require.Nil(t, err) err = transport.Close() assert.Nil(t, err) - err = transport.AddSubscriber(newSubscriber()) + err = transport.AddSubscriber(newSubscriber("")) assert.Equal(t, err, ErrClosedTransport) err = transport.Dispatch(&Update{}) @@ -100,13 +100,13 @@ func TestLiveCleanDisconnectedSubscribers(t *testing.T) { transport := NewLocalTransport() defer transport.Close() - s1 := newSubscriber() + s1 := newSubscriber("") go s1.start() err := transport.AddSubscriber(s1) require.Nil(t, err) - s2 := newSubscriber() + s2 := newSubscriber("") go s2.start() err = transport.AddSubscriber(s2) @@ -117,7 +117,7 @@ func TestLiveCleanDisconnectedSubscribers(t *testing.T) { s1.Disconnect() assert.Len(t, transport.subscribers, 2) - transport.Dispatch(&Update{Topics: s1.topics}) + transport.Dispatch(&Update{Topics: s1.Topics}) assert.Len(t, transport.subscribers, 1) s2.Disconnect() @@ -132,19 +132,19 @@ func TestLiveReading(t *testing.T) { defer transport.Close() assert.Implements(t, (*Transport)(nil), transport) - s := newSubscriber() - s.topics = []string{"https://example.com"} - s.rawTopics = s.topics + s := newSubscriber("") + s.Topics = []string{"https://example.com"} + s.RawTopics = s.Topics go s.start() err := transport.AddSubscriber(s) assert.Nil(t, err) - u := &Update{Topics: s.topics} + u := &Update{Topics: s.Topics} err = transport.Dispatch(u) assert.Nil(t, err) - receivedUpdate := <-s.Out + receivedUpdate := <-s.Receive() assert.Equal(t, u, receivedUpdate) }