From 0327c3a9f84b236eee3edc4ff66f1acab907b998 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Dunglas?= Date: Wed, 6 May 2020 15:55:02 +0200 Subject: [PATCH] Refactor and fix tests --- hub/authorization_test.go | 5 +- hub/bolt_transport.go | 14 +-- hub/bolt_transport_test.go | 191 ++++++++++++++++++++----------------- hub/config_test.go | 2 +- hub/hub_test.go | 4 +- hub/log.go | 44 ++++----- hub/metrics.go | 4 +- hub/metrics_test.go | 12 ++- hub/publish.go | 2 +- hub/publish_test.go | 39 +++++--- hub/server_test.go | 14 +-- hub/subscribe.go | 161 ++++++++++++++++--------------- hub/subscribe_test.go | 67 ++++++------- hub/subscriber.go | 137 +++++++++++++------------- hub/subscriber_test.go | 4 +- hub/transport.go | 2 +- hub/transport_test.go | 168 ++++++++++++++++---------------- 17 files changed, 439 insertions(+), 431 deletions(-) diff --git a/hub/authorization_test.go b/hub/authorization_test.go index 5ec2f7a9..8a66edee 100644 --- a/hub/authorization_test.go +++ b/hub/authorization_test.go @@ -3,7 +3,6 @@ package hub import ( "net/http" "testing" - "time" "github.com/spf13/viper" "github.com/stretchr/testify/assert" @@ -393,7 +392,7 @@ func TestAuthorizedAllTargetsSubscriber(t *testing.T) { func TestGetJWTKeyInvalid(t *testing.T) { v := viper.New() - h := createDummyWithTransportAndConfig(NewLocalTransport(5, time.Second), v) + h := createDummyWithTransportAndConfig(NewLocalTransport(), v) h.config.Set("publisher_jwt_key", "") assert.PanicsWithValue(t, "one of these configuration parameters must be defined: [publisher_jwt_key jwt_key]", func() { @@ -408,7 +407,7 @@ func TestGetJWTKeyInvalid(t *testing.T) { func TestGetJWTAlgorithmInvalid(t *testing.T) { v := viper.New() - h := createDummyWithTransportAndConfig(NewLocalTransport(5, time.Second), v) + h := createDummyWithTransportAndConfig(NewLocalTransport(), v) h.config.Set("publisher_jwt_algorithm", "foo") assert.PanicsWithValue(t, "invalid signing method: foo", func() { diff --git a/hub/bolt_transport.go b/hub/bolt_transport.go index 4b4a8c9b..e953ab43 100644 --- a/hub/bolt_transport.go +++ b/hub/bolt_transport.go @@ -139,29 +139,29 @@ func (t *BoltTransport) persist(updateID string, updateJSON []byte) error { // AddSubscriber adds a new subscriber to the transport. func (t *BoltTransport) AddSubscriber(s *Subscriber) error { - t.Lock() - defer t.Unlock() - select { case <-t.done: return ErrClosedTransport default: } + t.Lock() t.subscribers[s] = struct{}{} - if s.LastEventID == "" { + if s.History.In == nil { + t.Unlock() return nil } + t.Unlock() toSeq := t.lastSeq.Load() - t.dispatchFromHistory(s.LastEventID, toSeq, s) + t.dispatchFromHistory(s.lastEventID, toSeq, s) return nil } func (t *BoltTransport) dispatchFromHistory(lastEventID string, toSeq uint64, s *Subscriber) { t.db.View(func(tx *bolt.Tx) error { - defer close(s.HistorySrc.In) + defer close(s.History.In) b := tx.Bucket([]byte(t.bucketName)) if b == nil { return nil // No data @@ -204,7 +204,7 @@ func (t *BoltTransport) Close() error { t.Lock() defer t.Unlock() for subscriber := range t.subscribers { - close(subscriber.ServerDisconnect) + subscriber.Disconnect() delete(t.subscribers, subscriber) } close(t.done) diff --git a/hub/bolt_transport_test.go b/hub/bolt_transport_test.go index 3ff6b94f..e7f70d93 100644 --- a/hub/bolt_transport_test.go +++ b/hub/bolt_transport_test.go @@ -1,13 +1,11 @@ package hub import ( - "context" "net/url" "os" "strconv" "sync" "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -16,21 +14,31 @@ import ( func TestBoltTransportHistory(t *testing.T) { u, _ := url.Parse("bolt://test.db") - transport, _ := NewBoltTransport(u, 5, time.Second) + transport, _ := NewBoltTransport(u) defer transport.Close() defer os.Remove("test.db") + topics := []string{"https://example.com/foo"} for i := 1; i <= 10; i++ { - transport.Write(&Update{Event: Event{ID: strconv.Itoa(i)}}) + transport.Dispatch(&Update{ + Event: Event{ID: strconv.Itoa(i)}, + Topics: topics, + }) } - pipe, err := transport.CreatePipe("8") + s := newSubscriber() + s.topics = topics + s.rawTopics = topics + s.lastEventID = "8" + s.History.In = make(chan *Update) + go s.start() + + err := transport.AddSubscriber(s) assert.Nil(t, err) - require.NotNil(t, pipe) var count int for { - u := <-pipe.Read() + u := <-s.Out // the reading loop must read the #9 and #10 messages assert.Equal(t, strconv.Itoa(9+count), u.ID) count++ @@ -42,51 +50,64 @@ func TestBoltTransportHistory(t *testing.T) { func TestBoltTransportHistoryAndLive(t *testing.T) { u, _ := url.Parse("bolt://test.db") - transport, _ := NewBoltTransport(u, 5, time.Second) + transport, _ := NewBoltTransport(u) defer transport.Close() defer os.Remove("test.db") + topics := []string{"https://example.com/foo"} for i := 1; i <= 10; i++ { - transport.Write(&Update{Event: Event{ID: strconv.Itoa(i)}}) + transport.Dispatch(&Update{ + Topics: topics, + Event: Event{ID: strconv.Itoa(i)}, + }) } - pipe, err := transport.CreatePipe("8") + s := newSubscriber() + s.topics = topics + s.rawTopics = topics + s.lastEventID = "8" + s.History.In = make(chan *Update) + go s.start() + + err := transport.AddSubscriber(s) assert.Nil(t, err) - require.NotNil(t, pipe) var wg sync.WaitGroup wg.Add(1) go func() { + defer wg.Done() var count int for { - u, ok := <-pipe.Read() - if !ok { - return - } + u := <-s.Out // the reading loop must read the #9, #10 and #11 messages assert.Equal(t, strconv.Itoa(9+count), u.ID) count++ if count == 3 { - wg.Done() return } } }() - transport.Write(&Update{Event: Event{ID: "11"}}) + transport.Dispatch(&Update{ + Event: Event{ID: "11"}, + Topics: topics, + }) wg.Wait() } func TestBoltTransportPurgeHistory(t *testing.T) { u, _ := url.Parse("bolt://test.db?size=5&cleanup_frequency=1") - transport, _ := NewBoltTransport(u, 5, time.Second) + transport, _ := NewBoltTransport(u) defer transport.Close() defer os.Remove("test.db") for i := 0; i < 12; i++ { - transport.Write(&Update{Event: Event{ID: strconv.Itoa(i)}}) + transport.Dispatch(&Update{ + Event: Event{ID: strconv.Itoa(i)}, + Topics: []string{"https://example.com/foo"}, + }) } transport.db.View(func(tx *bolt.Tx) error { @@ -100,151 +121,147 @@ func TestBoltTransportPurgeHistory(t *testing.T) { func TestNewBoltTransport(t *testing.T) { u, _ := url.Parse("bolt://test.db?bucket_name=demo") - transport, err := NewBoltTransport(u, 5, time.Second) + transport, err := NewBoltTransport(u) assert.Nil(t, err) require.NotNil(t, transport) transport.Close() u, _ = url.Parse("bolt://") - _, err = NewBoltTransport(u, 5, time.Second) + _, err = NewBoltTransport(u) assert.EqualError(t, err, `invalid bolt DSN "bolt:": missing path`) u, _ = url.Parse("bolt:///test.db") - _, err = NewBoltTransport(u, 5, time.Second) + _, err = NewBoltTransport(u) // The exact error message depends of the OS assert.Contains(t, err.Error(), `invalid bolt DSN "bolt:///test.db": open /test.db: `) u, _ = url.Parse("bolt://test.db?cleanup_frequency=invalid") - _, err = NewBoltTransport(u, 5, time.Second) + _, err = NewBoltTransport(u) assert.EqualError(t, err, `invalid bolt "bolt://test.db?cleanup_frequency=invalid" dsn: parameter cleanup_frequency: strconv.ParseFloat: parsing "invalid": invalid syntax`) u, _ = url.Parse("bolt://test.db?size=invalid") - _, err = NewBoltTransport(u, 5, time.Second) + _, err = NewBoltTransport(u) assert.EqualError(t, err, `invalid bolt "bolt://test.db?size=invalid" dsn: parameter size: strconv.ParseUint: parsing "invalid": invalid syntax`) } -func TestBoltTransportWriteIsNotDispatchedUntilListen(t *testing.T) { +func TestBoltTransportDoNotDispatchedUntilListen(t *testing.T) { u, _ := url.Parse("bolt://test.db") - transport, _ := NewBoltTransport(u, 5, time.Second) + transport, _ := NewBoltTransport(u) defer transport.Close() defer os.Remove("test.db") assert.Implements(t, (*Transport)(nil), transport) - pipe, err := transport.CreatePipe("") + s := newSubscriber() + go s.start() + + err := transport.AddSubscriber(s) assert.Nil(t, err) - require.NotNil(t, pipe) var ( readUpdate *Update ok bool - m sync.Mutex wg sync.WaitGroup ) wg.Add(1) go func() { - m.Lock() - defer m.Unlock() - go wg.Done() - select { - case readUpdate = <-pipe.Read(): - case <-pipe.done: + case readUpdate = <-s.Out: + case <-s.disconnected: ok = true } + + wg.Done() }() - wg.Wait() - pipe.Close() + s.Disconnect() - m.Lock() - defer m.Unlock() + wg.Wait() assert.Nil(t, readUpdate) assert.True(t, ok) } -func TestBoltTransportWriteIsDispatched(t *testing.T) { - u, _ := url.Parse("bolt://test.db") - transport, _ := NewBoltTransport(u, 5, time.Second) +func TestBoltTransportDispatch(t *testing.T) { + ur, _ := url.Parse("bolt://test.db") + transport, _ := NewBoltTransport(ur) defer transport.Close() defer os.Remove("test.db") assert.Implements(t, (*Transport)(nil), transport) - pipe, err := transport.CreatePipe("") - assert.Nil(t, err) - require.NotNil(t, pipe) - defer pipe.Close() + s := newSubscriber() + s.topics = []string{"https://example.com/foo"} + s.rawTopics = s.topics + go s.start() - var ( - readUpdate *Update - ok bool - m sync.Mutex - wg sync.WaitGroup - ) - wg.Add(1) - go func() { - m.Lock() - defer m.Unlock() + err := transport.AddSubscriber(s) + assert.Nil(t, err) - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - go wg.Done() - select { - case readUpdate, ok = <-pipe.Read(): - case <-ctx.Done(): - } - }() + u := &Update{Topics: s.topics} - wg.Wait() - err = transport.Write(&Update{}) + err = transport.Dispatch(u) assert.Nil(t, err) - m.Lock() - defer m.Unlock() - - assert.True(t, ok) - assert.NotNil(t, readUpdate) + readUpdate := <-s.Out + assert.Equal(t, u, readUpdate) } func TestBoltTransportClosed(t *testing.T) { u, _ := url.Parse("bolt://test.db") - transport, _ := NewBoltTransport(u, 5, time.Second) + transport, _ := NewBoltTransport(u) require.NotNil(t, transport) defer transport.Close() defer os.Remove("test.db") assert.Implements(t, (*Transport)(nil), transport) - pipe, _ := transport.CreatePipe("") - require.NotNil(t, pipe) + s := newSubscriber() + s.topics = []string{"https://example.com/foo"} + s.rawTopics = s.topics + go s.start() + + err := transport.AddSubscriber(s) + require.Nil(t, err) - err := transport.Close() + err = transport.Close() assert.Nil(t, err) - _, err = transport.CreatePipe("") + err = transport.AddSubscriber(s) assert.Equal(t, err, ErrClosedTransport) - err = transport.Write(&Update{}) + err = transport.Dispatch(&Update{Topics: s.topics}) assert.Equal(t, err, ErrClosedTransport) - _, ok := <-pipe.Read() + _, ok := <-s.disconnected assert.False(t, ok) } -func TestBoltCleanClosedPipes(t *testing.T) { +func TestBoltCleanDisconnectedSubscribers(t *testing.T) { u, _ := url.Parse("bolt://test.db") - transport, _ := NewBoltTransport(u, 5, time.Second) + transport, _ := NewBoltTransport(u) require.NotNil(t, transport) defer transport.Close() defer os.Remove("test.db") - pipe, _ := transport.CreatePipe("") - require.NotNil(t, pipe) + s1 := newSubscriber() + go s1.start() + err := transport.AddSubscriber(s1) + require.Nil(t, err) + + s2 := newSubscriber() + go s2.start() + err = transport.AddSubscriber(s2) + require.Nil(t, err) + + assert.Len(t, transport.subscribers, 2) + + s1.Disconnect() + assert.Len(t, transport.subscribers, 2) - assert.Len(t, transport.pipes, 1) + transport.Dispatch(&Update{Topics: s1.topics}) + assert.Len(t, transport.subscribers, 1) - pipe.Close() - assert.Len(t, transport.pipes, 1) + s2.Disconnect() + assert.Len(t, transport.subscribers, 1) - transport.Write(&Update{}) - assert.Len(t, transport.pipes, 0) + transport.Dispatch(&Update{}) + assert.Len(t, transport.subscribers, 0) } diff --git a/hub/config_test.go b/hub/config_test.go index c604e54c..967fb8d8 100644 --- a/hub/config_test.go +++ b/hub/config_test.go @@ -37,7 +37,7 @@ func TestSetFlags(t *testing.T) { fs := pflag.NewFlagSet("test", pflag.PanicOnError) SetFlags(fs, v) - assert.Subset(t, v.AllKeys(), []string{"cert_file", "compress", "demo", "jwt_algorithm", "transport_url", "acme_hosts", "acme_cert_dir", "subscriber_jwt_key", "log_format", "jwt_key", "allow_anonymous", "debug", "read_timeout", "publisher_jwt_algorithm", "write_timeout", "key_file", "use_forwarded_headers", "subscriber_jwt_algorithm", "addr", "publisher_jwt_key", "heartbeat_interval", "cors_allowed_origins", "publish_allowed_origins", "dispatch_subscriptions", "subscriptions_include_ip", "metrics", "update_buffer_size", "update_buffer_full_timeout"}) + assert.Subset(t, v.AllKeys(), []string{"cert_file", "compress", "demo", "jwt_algorithm", "transport_url", "acme_hosts", "acme_cert_dir", "subscriber_jwt_key", "log_format", "jwt_key", "allow_anonymous", "debug", "read_timeout", "publisher_jwt_algorithm", "write_timeout", "key_file", "use_forwarded_headers", "subscriber_jwt_algorithm", "addr", "publisher_jwt_key", "heartbeat_interval", "cors_allowed_origins", "publish_allowed_origins", "dispatch_subscriptions", "subscriptions_include_ip", "metrics", "dispatch_timeout"}) } func TestInitConfig(t *testing.T) { diff --git a/hub/hub_test.go b/hub/hub_test.go index f00cce4a..e13f4311 100644 --- a/hub/hub_test.go +++ b/hub/hub_test.go @@ -69,11 +69,11 @@ func createDummy() *Hub { v.SetDefault("publisher_jwt_key", "publisher") v.SetDefault("subscriber_jwt_key", "subscriber") - return NewHubWithTransport(v, NewLocalTransport(5, time.Second)) + return NewHubWithTransport(v, NewLocalTransport()) } func createAnonymousDummy() *Hub { - return createDummyWithTransportAndConfig(NewLocalTransport(5, time.Second), viper.New()) + return createDummyWithTransportAndConfig(NewLocalTransport(), viper.New()) } func createDummyWithTransportAndConfig(t Transport, v *viper.Viper) *Hub { diff --git a/hub/log.go b/hub/log.go index f935ede9..422f3235 100644 --- a/hub/log.go +++ b/hub/log.go @@ -1,47 +1,35 @@ package hub import ( - "net/http" - fluentd "github.com/joonix/log" log "github.com/sirupsen/logrus" "github.com/spf13/viper" ) -// TODO: delete me -func (h *Hub) createLogFields(r *http.Request, u *Update, s *Subscriber) log.Fields { - return createBaseLogFields(h.config.GetBool("debug"), r.RemoteAddr, u, s) -} +func addUpdateFields(f log.Fields, u *Update, debug bool) log.Fields { + f["event_id"] = u.ID + f["event_type"] = u.Type + f["event_retry"] = u.Retry + f["update_topics"] = u.Topics + f["update_targets"] = targetsMapToSlice(u.Targets) -// TODO: rename me -func createBaseLogFields(debug bool, remoteAddr string, u *Update, s *Subscriber) log.Fields { - fields := log.Fields{ - "remote_addr": remoteAddr, + if debug { + f["update_data"] = u.Data } - if u != nil { - fields["event_id"] = u.ID - fields["event_type"] = u.Type - fields["event_retry"] = u.Retry - fields["update_topics"] = u.Topics - fields["update_targets"] = targetsMapToArray(u.Targets) - - if debug { - fields["update_data"] = u.Data - } - - } + return f +} - if s != nil { - fields["last_event_id"] = s.LastEventID - fields["subscriber_topics"] = s.Topics - fields["subscriber_targets"] = targetsMapToArray(s.Targets) +func createFields(u *Update, s *Subscriber) log.Fields { + f := addUpdateFields(log.Fields{}, u, s.debug) + for k, v := range s.logFields { + f[k] = v } - return fields + return f } -func targetsMapToArray(t map[string]struct{}) []string { +func targetsMapToSlice(t map[string]struct{}) []string { targets := make([]string, len(t)) var i int diff --git a/hub/metrics.go b/hub/metrics.go index c81a0932..78821ceb 100644 --- a/hub/metrics.go +++ b/hub/metrics.go @@ -59,7 +59,7 @@ func (m *Metrics) Register(r *mux.Router) { // NewSubscriber collects metrics about new subscriber events. func (m *Metrics) NewSubscriber(s *Subscriber) { - for _, t := range s.Topics { + for _, t := range s.topics { m.subscribersTotal.WithLabelValues(t).Inc() m.subscribers.WithLabelValues(t).Inc() } @@ -67,7 +67,7 @@ func (m *Metrics) NewSubscriber(s *Subscriber) { // SubscriberDisconnect collects metrics about subscriber disconnection events. func (m *Metrics) SubscriberDisconnect(s *Subscriber) { - for _, t := range s.Topics { + for _, t := range s.topics { m.subscribers.WithLabelValues(t).Dec() } } diff --git a/hub/metrics_test.go b/hub/metrics_test.go index 782c7641..af7a2f14 100644 --- a/hub/metrics_test.go +++ b/hub/metrics_test.go @@ -11,12 +11,14 @@ import ( func TestNumberOfRunningSubscribers(t *testing.T) { m := NewMetrics() - s1 := NewSubscriber(false, nil, []string{"topic1", "topic2"}, []string{"topic1", "topic2"}, nil, "lid1") + s1 := newSubscriber() + s1.topics = []string{"topic1", "topic2"} m.NewSubscriber(s1) assertGaugeLabelValue(t, 1.0, m.subscribers, "topic1") assertGaugeLabelValue(t, 1.0, m.subscribers, "topic2") - s2 := NewSubscriber(false, nil, []string{"topic2"}, []string{"topic2"}, nil, "lid2") + s2 := newSubscriber() + s2.topics = []string{"topic2"} m.NewSubscriber(s2) assertGaugeLabelValue(t, 1.0, m.subscribers, "topic1") assertGaugeLabelValue(t, 2.0, m.subscribers, "topic2") @@ -33,12 +35,14 @@ func TestNumberOfRunningSubscribers(t *testing.T) { func TestTotalNumberOfHandledSubscribers(t *testing.T) { m := NewMetrics() - s1 := NewSubscriber(false, nil, []string{"topic1", "topic2"}, []string{"topic1", "topic2"}, nil, "lid1") + s1 := newSubscriber() + s1.topics = []string{"topic1", "topic2"} m.NewSubscriber(s1) assertCounterValue(t, 1.0, m.subscribersTotal, "topic1") assertCounterValue(t, 1.0, m.subscribersTotal, "topic2") - s2 := NewSubscriber(false, nil, []string{"topic2"}, []string{"topic2"}, nil, "lid2") + s2 := newSubscriber() + s2.topics = []string{"topic2"} m.NewSubscriber(s2) assertCounterValue(t, 1.0, m.subscribersTotal, "topic1") assertCounterValue(t, 2.0, m.subscribersTotal, "topic2") diff --git a/hub/publish.go b/hub/publish.go index e8e90c64..9ed0968b 100644 --- a/hub/publish.go +++ b/hub/publish.go @@ -72,7 +72,7 @@ func (h *Hub) PublishHandler(w http.ResponseWriter, r *http.Request) { } io.WriteString(w, u.ID) - log.WithFields(h.createLogFields(r, u, nil)).Info("Update published") + log.WithFields(addUpdateFields(log.Fields{"remote_addr": r.RemoteAddr}, u, h.config.GetBool("debug"))).Info("Update published") h.metrics.NewUpdate(u) } diff --git a/hub/publish_test.go b/hub/publish_test.go index 1fd84351..bf671374 100644 --- a/hub/publish_test.go +++ b/hub/publish_test.go @@ -153,16 +153,22 @@ func TestPublishNotAuthorizedTarget(t *testing.T) { func TestPublishOK(t *testing.T) { hub := createDummy() + defer hub.Stop() - pipe, err := hub.transport.CreatePipe("") + s := newSubscriber() + s.topics = []string{"http://example.com/books/1"} + s.rawTopics = s.topics + s.targets = map[string]struct{}{"foo": {}} + go s.start() + + err := hub.transport.AddSubscriber(s) assert.Nil(t, err) - require.NotNil(t, pipe) var wg sync.WaitGroup wg.Add(1) go func(w *sync.WaitGroup) { defer w.Done() - u, ok := <-pipe.Read() + u, ok := <-s.Out assert.True(t, ok) require.NotNil(t, u) assert.Equal(t, "id", u.ID) @@ -197,20 +203,25 @@ func TestPublishOK(t *testing.T) { } func TestPublishGenerateUUID(t *testing.T) { - hub := createDummy() + h := createDummy() + defer h.Stop() - pipe, err := hub.transport.CreatePipe("") - assert.Nil(t, err) - require.NotNil(t, pipe) + s := newSubscriber() + s.topics = []string{"http://example.com/books/1"} + s.rawTopics = s.topics + s.targets = map[string]struct{}{"foo": {}} + go s.start() + + h.transport.AddSubscriber(s) var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() - u, ok := <-pipe.Read() - assert.True(t, ok) + u := <-s.Out require.NotNil(t, u) - _, err = uuid.FromString(u.ID) + + _, err := uuid.FromString(u.ID) assert.Nil(t, err) }() @@ -220,11 +231,11 @@ func TestPublishGenerateUUID(t *testing.T) { req := httptest.NewRequest("POST", defaultHubURL, strings.NewReader(form.Encode())) req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - req.AddCookie(&http.Cookie{Name: "mercureAuthorization", Value: createDummyAuthorizedJWT(hub, publisherRole, []string{})}) - req.Header.Add("Authorization", "Bearer "+createDummyAuthorizedJWT(hub, publisherRole, []string{})) + //req.AddCookie(&http.Cookie{Name: "mercureAuthorization", Value: createDummyAuthorizedJWT(hub, publisherRole, []string{})}) + req.Header.Add("Authorization", "Bearer "+createDummyAuthorizedJWT(h, publisherRole, []string{})) w := httptest.NewRecorder() - hub.PublishHandler(w, req) + h.PublishHandler(w, req) resp := w.Result() defer resp.Body.Close() @@ -233,7 +244,7 @@ func TestPublishGenerateUUID(t *testing.T) { bodyBytes, _ := ioutil.ReadAll(resp.Body) body := string(bodyBytes) - _, err = uuid.FromString(body) + _, err := uuid.FromString(body) assert.Nil(t, err) wg.Wait() diff --git a/hub/server_test.go b/hub/server_test.go index 309d558a..919dac02 100644 --- a/hub/server_test.go +++ b/hub/server_test.go @@ -25,7 +25,7 @@ const testSecureURL = "https://" + testAddr + defaultHubURL func TestForwardedHeaders(t *testing.T) { v := viper.New() v.Set("use_forwarded_headers", true) - h := createDummyWithTransportAndConfig(NewLocalTransport(5, time.Second), v) + h := createDummyWithTransportAndConfig(NewLocalTransport(), v) go h.Serve() @@ -61,7 +61,7 @@ func TestSecurityOptions(t *testing.T) { v.Set("cert_file", "../fixtures/tls/server.crt") v.Set("key_file", "../fixtures/tls/server.key") v.Set("compress", true) - h := createDummyWithTransportAndConfig(NewLocalTransport(5, time.Second), v) + h := createDummyWithTransportAndConfig(NewLocalTransport(), v) go h.Serve() @@ -171,7 +171,7 @@ func TestServe(t *testing.T) { func TestClientClosesThenReconnects(t *testing.T) { u, _ := url.Parse("bolt://test.db") - transport, _ := NewBoltTransport(u, 5, time.Second) + transport, _ := NewBoltTransport(u) defer os.Remove("test.db") h := createDummyWithTransportAndConfig(transport, viper.New()) @@ -216,7 +216,7 @@ func TestClientClosesThenReconnects(t *testing.T) { publish := func(data string, waitForSubscribers int) { for { transport.Lock() - l := len(transport.pipes) + l := len(transport.subscribers) transport.Unlock() if l >= waitForSubscribers { break @@ -274,7 +274,7 @@ func TestServeAcme(t *testing.T) { v.Set("acme_hosts", []string{"example.com"}) v.Set("acme_http01_addr", ":8080") v.Set("acme_cert_dir", dir) - h := createDummyWithTransportAndConfig(NewLocalTransport(5, time.Second), v) + h := createDummyWithTransportAndConfig(NewLocalTransport(), v) go h.Serve() @@ -306,7 +306,7 @@ func TestServeAcme(t *testing.T) { func TestMetricsAccess(t *testing.T) { v := viper.New() v.Set("metrics", true) - h := createDummyWithTransportAndConfig(NewLocalTransport(5, time.Second), v) + h := createDummyWithTransportAndConfig(NewLocalTransport(), v) go h.Serve() @@ -360,7 +360,7 @@ type testServer struct { } func newTestServer(t *testing.T, v *viper.Viper) testServer { - h := createDummyWithTransportAndConfig(NewLocalTransport(5, time.Second), v) + h := createDummyWithTransportAndConfig(NewLocalTransport(), v) go func() { h.Serve() diff --git a/hub/subscribe.go b/hub/subscribe.go index 41dc37e2..7970dd97 100644 --- a/hub/subscribe.go +++ b/hub/subscribe.go @@ -1,7 +1,6 @@ package hub import ( - "context" "encoding/json" "fmt" "io" @@ -25,115 +24,110 @@ type subscription struct { Address string `json:"address,omitempty"` } -// SubscribeHandler create a keep alive connection and send the events to the subscribers. +// SubscribeHandler creates a keep alive connection and sends the events to the subscribers. func (h *Hub) SubscribeHandler(w http.ResponseWriter, r *http.Request) { _, ok := w.(http.Flusher) if !ok { panic("http.ResponseWriter must be an instance of http.Flusher") } - subscriber, unsubscribed, ok := h.initSubscription(w, r) - if !ok { + debug := h.config.GetBool("debug") + s := h.registerSubscriber(w, r, debug) + if s == nil { return } - defer h.cleanup(subscriber) - defer unsubscribed() - // Notify that the client is closing the connection - defer close(subscriber.ClientDisconnect) + defer h.shutdown(s) hearthbeatInterval := h.config.GetDuration("heartbeat_interval") - var cancelHearthbeatTimeout context.CancelFunc - for { - ctxHearthbeat := context.Background() - if hearthbeatInterval != time.Duration(0) { - ctxHearthbeat, cancelHearthbeatTimeout = context.WithTimeout(ctxHearthbeat, hearthbeatInterval) - defer cancelHearthbeatTimeout() - } + var timer *time.Timer + var timerC <-chan time.Time + if hearthbeatInterval != time.Duration(0) { + timer = time.NewTimer(hearthbeatInterval) + timerC = timer.C + } + + for { select { - case <-subscriber.ServerDisconnect: + case <-s.disconnected: // Server closes the connection return case <-r.Context().Done(): // Client closes the connection return - case <-ctxHearthbeat.Done(): - if ctxHearthbeat.Err() != context.DeadlineExceeded { - break - } + case <-timerC: // Send a SSE comment as a heartbeat, to prevent issues with some proxies and old browsers - if !h.write(w, r, subscriber, ":\n") { + if !h.write(w, r, s, ":\n") { return } - case update := <-subscriber.Out: - if !h.write(w, r, subscriber, newSerializedUpdate(update).event) { + timer.Reset(hearthbeatInterval) + case update := <-s.Out: + if !h.write(w, r, s, newSerializedUpdate(update).event) { return } - if cancelHearthbeatTimeout != nil { - cancelHearthbeatTimeout() + if timer != nil { + if !timer.Stop() { + <-timer.C + } + timer.Reset(hearthbeatInterval) } - - fields := createBaseLogFields(subscriber.debug, r.RemoteAddr, update, subscriber) - log.WithFields(fields).Info("Event sent") + log.WithFields(createFields(update, s)).Info("Event sent") } } } -// initSubscription initializes the connection. -func (h *Hub) initSubscription(w http.ResponseWriter, r *http.Request) (*Subscriber, func(), bool) { - fields := log.Fields{"remote_addr": r.RemoteAddr} +// registerSubscriber initializes the connection. +func (h *Hub) registerSubscriber(w http.ResponseWriter, r *http.Request, debug bool) *Subscriber { + s := newSubscriber() + s.debug = debug + s.logFields["remote_addr"] = r.RemoteAddr claims, err := authorize(r, h.getJWTKey(subscriberRole), h.getJWTAlgorithm(subscriberRole), nil) - debug := h.config.GetBool("debug") - if debug && claims != nil { - fields["target"] = claims.Mercure.Subscribe + if claims != nil { + s.claims = claims + s.logFields["subscriber_targets"] = claims.Mercure.Subscribe } if err != nil || (claims == nil && !h.config.GetBool("allow_anonymous")) { http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) - log.WithFields(fields).Info(err) - return nil, nil, false + log.WithFields(s.logFields).Info(err) + return nil } - topics := r.URL.Query()["topic"] - if len(topics) == 0 { + s.topics = r.URL.Query()["topic"] + if len(s.topics) == 0 { http.Error(w, "Missing \"topic\" parameter.", http.StatusBadRequest) - return nil, nil, false + return nil } - fields["subscriber_topics"] = topics + s.logFields["subscriber_topics"] = s.topics - rawTopics, templateTopics := h.parseTopics(topics) + s.rawTopics, s.templateTopics = h.parseTopics(s.topics) + s.escapedTopics = escapeTopics(s.topics) + s.allTargets, s.targets = authorizedTargets(claims, false) + s.remoteAddr = r.RemoteAddr - authorizedAlltargets, authorizedTargets := authorizedTargets(claims, false) - subscriber := NewSubscriber(authorizedAlltargets, authorizedTargets, topics, rawTopics, templateTopics, retrieveLastEventID(r), r.RemoteAddr, debug) - encodedTopics := escapeTopics(topics) + s.lastEventID = retrieveLastEventID(r) + if s.lastEventID != "" { + s.History.In = make(chan *Update) + s.logFields["last_event_id"] = s.lastEventID + } + go s.start() - // TODO: move this to the subscriber struct - connectionID := uuid.Must(uuid.NewV4()).String() - var address string if h.config.GetBool("subscriptions_include_ip") { - address, _, _ = net.SplitHostPort(r.RemoteAddr) + s.remoteHost, _, _ = net.SplitHostPort(r.RemoteAddr) } - // TODO: dispatchSubscriptionUpdate(subscriber) - h.dispatchSubscriptionUpdate(topics, encodedTopics, connectionID, claims, true, address) - if h.transport.AddSubscriber(subscriber) != nil { - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) - h.dispatchSubscriptionUpdate(topics, encodedTopics, connectionID, claims, false, address) - log.WithFields(fields).Error(err) - return nil, nil, false + h.dispatchSubscriptionUpdate(s, true) + if h.transport.AddSubscriber(s) != nil { + http.Error(w, http.StatusText(http.StatusServiceUnavailable), http.StatusServiceUnavailable) + h.dispatchSubscriptionUpdate(s, false) + log.WithFields(s.logFields).Error(err) + return nil } sendHeaders(w) - log.WithFields(fields).Info("New subscriber") + log.WithFields(s.logFields).Info("New subscriber") - h.metrics.NewSubscriber(subscriber) + h.metrics.NewSubscriber(s) - unsubscribed := func() { - h.dispatchSubscriptionUpdate(topics, encodedTopics, connectionID, claims, false, address) - log.WithFields(fields).Info("Subscriber disconnected") - - h.metrics.SubscriberDisconnect(subscriber) - } - - return subscriber, unsubscribed, true + return s } func (h *Hub) parseTopics(topics []string) (rawTopics []string, templateTopics []*uritemplate.Template) { @@ -222,17 +216,22 @@ func (h *Hub) write(w io.Writer, r *http.Request, s *Subscriber, data string) bo case <-done: return true case <-time.After(d): - fields := createBaseLogFields(s.debug, r.RemoteAddr, nil, s) - log.WithFields(fields).Warn("Dispatch timeout reached") + log.WithFields(s.logFields).Warn("Dispatch timeout reached") return false } } -// cleanup removes unused uritemplate.Template instances from memory. -func (h *Hub) cleanup(s *Subscriber) { - keys := make([]string, 0, len(s.RawTopics)+len(s.TemplateTopics)) - copy(s.RawTopics, keys) - for _, uriTemplate := range s.TemplateTopics { +func (h *Hub) shutdown(s *Subscriber) { + // Notify that the client is closing the connection + s.Disconnect() + h.dispatchSubscriptionUpdate(s, false) + log.WithFields(s.logFields).Info("Subscriber disconnected") + h.metrics.SubscriberDisconnect(s) + + // Remove unused uritemplate.Template instances from memory. + keys := make([]string, 0, len(s.rawTopics)+len(s.templateTopics)) + copy(s.rawTopics, keys) + for _, uriTemplate := range s.templateTopics { keys = append(keys, uriTemplate.Raw()) } @@ -248,21 +247,21 @@ func (h *Hub) cleanup(s *Subscriber) { h.uriTemplates.Unlock() } -func (h *Hub) dispatchSubscriptionUpdate(topics, encodedTopics []string, connectionID string, claims *claims, active bool, address string) { +func (h *Hub) dispatchSubscriptionUpdate(s *Subscriber, active bool) { if !h.config.GetBool("dispatch_subscriptions") { return } - for k, topic := range topics { + for k, topic := range s.topics { connection := &subscription{ - ID: "https://mercure.rocks/subscriptions/" + encodedTopics[k] + "/" + connectionID, + ID: "https://mercure.rocks/subscriptions/" + s.escapedTopics[k] + "/" + s.ID, Type: "https://mercure.rocks/Subscription", Topic: topic, Active: active, - Address: address, + Address: s.remoteHost, } - if claims == nil { + if s.claims == nil { connection.mercureClaim.Publish = []string{} connection.mercureClaim.Subscribe = []string{} } else { @@ -281,7 +280,7 @@ func (h *Hub) dispatchSubscriptionUpdate(topics, encodedTopics []string, connect u := &Update{ Topics: []string{connection.ID}, - Targets: map[string]struct{}{"https://mercure.rocks/targets/subscriptions": {}, "https://mercure.rocks/targets/subscriptions/" + encodedTopics[k]: {}}, + Targets: map[string]struct{}{"https://mercure.rocks/targets/subscriptions": {}, "https://mercure.rocks/targets/subscriptions/" + s.escapedTopics[k]: {}}, Event: Event{Data: string(json), ID: uuid.Must(uuid.NewV4()).String()}, } @@ -290,10 +289,10 @@ func (h *Hub) dispatchSubscriptionUpdate(topics, encodedTopics []string, connect } func escapeTopics(topics []string) []string { - encodedTopics := make([]string, 0, len(topics)) + escapedTopics := make([]string, 0, len(topics)) for _, topic := range topics { - encodedTopics = append(encodedTopics, url.QueryEscape(topic)) + escapedTopics = append(escapedTopics, url.QueryEscape(topic)) } - return encodedTopics + return escapedTopics } diff --git a/hub/subscribe_test.go b/hub/subscribe_test.go index a0c64fc8..590b86a7 100644 --- a/hub/subscribe_test.go +++ b/hub/subscribe_test.go @@ -154,23 +154,23 @@ func TestSubscribeNoTopic(t *testing.T) { assert.Equal(t, "Missing \"topic\" parameter.\n", w.Body.String()) } -type createPipeErrorTransport struct { +type addSubscriberErrorTransport struct { } -func (*createPipeErrorTransport) Write(update *Update) error { +func (*addSubscriberErrorTransport) Dispatch(*Update) error { return nil } -func (*createPipeErrorTransport) CreatePipe(fromID string) (*Pipe, error) { - return nil, fmt.Errorf("Failed to create a pipe") +func (*addSubscriberErrorTransport) AddSubscriber(*Subscriber) error { + return fmt.Errorf("Failed to create a pipe") } -func (*createPipeErrorTransport) Close() error { +func (*addSubscriberErrorTransport) Close() error { return nil } -func TestSubscribeCreatePipeError(t *testing.T) { - hub := createDummyWithTransportAndConfig(&createPipeErrorTransport{}, viper.New()) +func TestSubscribeAddSubscriberError(t *testing.T) { + hub := createDummyWithTransportAndConfig(&addSubscriberErrorTransport{}, viper.New()) req := httptest.NewRequest("GET", defaultHubURL+"?topic=foo", nil) w := httptest.NewRecorder() @@ -180,8 +180,8 @@ func TestSubscribeCreatePipeError(t *testing.T) { resp := w.Result() defer resp.Body.Close() - assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) - assert.Equal(t, http.StatusText(http.StatusInternalServerError)+"\n", w.Body.String()) + assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode) + assert.Equal(t, http.StatusText(http.StatusServiceUnavailable)+"\n", w.Body.String()) } func testSubscribe(numberOfSubscribers int, t *testing.T) { @@ -191,7 +191,7 @@ func testSubscribe(numberOfSubscribers int, t *testing.T) { for { s, _ := hub.transport.(*LocalTransport) s.RLock() - ready := len(s.pipes) == numberOfSubscribers + ready := len(s.subscribers) == numberOfSubscribers s.RUnlock() // There is a problem (probably related to Logrus?) preventing the benchmark to work without this line. @@ -200,23 +200,23 @@ func testSubscribe(numberOfSubscribers int, t *testing.T) { continue } - hub.transport.Write(&Update{ + hub.transport.Dispatch(&Update{ Topics: []string{"http://example.com/not-subscribed"}, Event: Event{Data: "Hello World", ID: "a"}, }) - hub.transport.Write(&Update{ + hub.transport.Dispatch(&Update{ Topics: []string{"http://example.com/books/1"}, Event: Event{Data: "Hello World", ID: "b"}, }) - hub.transport.Write(&Update{ + hub.transport.Dispatch(&Update{ Topics: []string{"http://example.com/reviews/22"}, Event: Event{Data: "Great", ID: "c"}, }) - hub.transport.Write(&Update{ + hub.transport.Dispatch(&Update{ Topics: []string{"http://example.com/hub?topic=faulty{iri"}, Event: Event{Data: "Faulty IRI", ID: "d"}, }) - hub.transport.Write(&Update{ + hub.transport.Dispatch(&Update{ Topics: []string{"string"}, Event: Event{Data: "string", ID: "e"}, }) @@ -256,7 +256,7 @@ func TestUnsubscribe(t *testing.T) { hub := createAnonymousDummy() s, _ := hub.transport.(*LocalTransport) - assert.Equal(t, 0, len(s.pipes)) + assert.Equal(t, 0, len(s.subscribers)) ctx, cancel := context.WithCancel(context.Background()) var wg sync.WaitGroup @@ -265,15 +265,16 @@ func TestUnsubscribe(t *testing.T) { defer wg.Done() req := httptest.NewRequest("GET", defaultHubURL+"?topic=http://example.com/books/1", nil).WithContext(ctx) hub.SubscribeHandler(httptest.NewRecorder(), req) - assert.Equal(t, 1, len(s.pipes)) - for pipe := range s.pipes { - assert.True(t, pipe.IsClosed()) + assert.Equal(t, 1, len(s.subscribers)) + for s := range s.subscribers { + _, ok := <-s.disconnected + assert.False(t, ok) } }() for { s.RLock() - notEmpty := len(s.pipes) != 0 + notEmpty := len(s.subscribers) != 0 s.RUnlock() if notEmpty { break @@ -292,24 +293,24 @@ func TestSubscribeTarget(t *testing.T) { go func() { for { s.RLock() - empty := len(s.pipes) == 0 + empty := len(s.subscribers) == 0 s.RUnlock() if empty { continue } - hub.transport.Write(&Update{ + hub.transport.Dispatch(&Update{ Targets: map[string]struct{}{"baz": {}}, Topics: []string{"http://example.com/reviews/21"}, Event: Event{Data: "Foo", ID: "a"}, }) - hub.transport.Write(&Update{ + hub.transport.Dispatch(&Update{ Targets: map[string]struct{}{}, Topics: []string{"http://example.com/reviews/22"}, Event: Event{Data: "Hello World", ID: "b", Type: "test"}, }) - hub.transport.Write(&Update{ + hub.transport.Dispatch(&Update{ Targets: map[string]struct{}{"hello": {}, "bar": {}}, Topics: []string{"http://example.com/reviews/23"}, Event: Event{Data: "Great", ID: "c", Retry: 1}, @@ -388,7 +389,7 @@ func TestSubscriptionEvents(t *testing.T) { s, _ := hub.transport.(*LocalTransport) for { s.RLock() - ready := len(s.pipes) == 2 + ready := len(s.subscribers) == 2 s.RUnlock() log.Info("Waiting for subscriber...") @@ -424,19 +425,19 @@ func TestSubscribeAllTargets(t *testing.T) { go func() { for { s.RLock() - empty := len(s.pipes) == 0 + empty := len(s.subscribers) == 0 s.RUnlock() if empty { continue } - hub.transport.Write(&Update{ + hub.transport.Dispatch(&Update{ Targets: map[string]struct{}{"foo": {}}, Topics: []string{"http://example.com/reviews/21"}, Event: Event{Data: "Foo", ID: "a"}, }) - hub.transport.Write(&Update{ + hub.transport.Dispatch(&Update{ Targets: map[string]struct{}{"bar": {}}, Topics: []string{"http://example.com/reviews/22"}, Event: Event{Data: "Hello World", ID: "b", Type: "test"}, @@ -463,20 +464,20 @@ func TestSubscribeAllTargets(t *testing.T) { func TestSendMissedEvents(t *testing.T) { u, _ := url.Parse("bolt://test.db") - transport, _ := NewBoltTransport(u, 5, time.Second) + transport, _ := NewBoltTransport(u) defer transport.Close() defer os.Remove("test.db") hub := createDummyWithTransportAndConfig(transport, viper.New()) - transport.Write(&Update{ + transport.Dispatch(&Update{ Topics: []string{"http://example.com/foos/a"}, Event: Event{ ID: "a", Data: "d1", }, }) - transport.Write(&Update{ + transport.Dispatch(&Update{ Topics: []string{"http://example.com/foos/b"}, Event: Event{ ID: "b", @@ -532,14 +533,14 @@ func TestSubscribeHeartbeat(t *testing.T) { go func() { for { s.RLock() - empty := len(s.pipes) == 0 + empty := len(s.subscribers) == 0 s.RUnlock() if empty { continue } - hub.transport.Write(&Update{ + hub.transport.Dispatch(&Update{ Topics: []string{"http://example.com/books/1"}, Event: Event{Data: "Hello World", ID: "b"}, }) diff --git a/hub/subscriber.go b/hub/subscriber.go index 6f87d52b..c0af1135 100644 --- a/hub/subscriber.go +++ b/hub/subscriber.go @@ -1,98 +1,83 @@ package hub import ( + "github.com/gofrs/uuid" log "github.com/sirupsen/logrus" "github.com/yosida95/uritemplate" ) -type updateSrc struct { +type updateSource struct { In chan *Update buffer []*Update } // Subscriber represents a client subscribed to a list of topics. type Subscriber struct { - AllTargets bool - Targets map[string]struct{} - Topics []string - RawTopics []string - TemplateTopics []*uritemplate.Template - LastEventID string - RemoteAddr string - - HistorySrc updateSrc - LiveSrc updateSrc - Out chan *Update - - ClientDisconnect chan struct{} - ServerDisconnect chan struct{} - - debug bool + ID string + History updateSource + Live updateSource + Out chan *Update + + disconnected chan struct{} + claims *claims + allTargets bool + targets map[string]struct{} + topics []string + escapedTopics []string + rawTopics []string + templateTopics []*uritemplate.Template + lastEventID string + remoteAddr string + remoteHost string + debug bool + + logFields log.Fields matchCache map[string]bool } -// NewSubscriber creates a subscriber. -func NewSubscriber(allTargets bool, targets map[string]struct{}, topics []string, rawTopics []string, templateTopics []*uritemplate.Template, lastEventID string, remoteAddr string, debug bool) *Subscriber { - s := &Subscriber{ - allTargets, - targets, - topics, - rawTopics, - templateTopics, - lastEventID, - remoteAddr, - - updateSrc{}, - updateSrc{In: make(chan *Update)}, - make(chan *Update), - - make(chan struct{}), - make(chan struct{}), - - debug, - make(map[string]bool), +func newSubscriber() *Subscriber { + id := uuid.Must(uuid.NewV4()).String() + return &Subscriber{ + ID: id, + History: updateSource{}, + Live: updateSource{In: make(chan *Update)}, + Out: make(chan *Update), + disconnected: make(chan struct{}), + logFields: log.Fields{"subscriber_id": id}, + matchCache: make(map[string]bool), } - - if lastEventID != "" { - s.HistorySrc.In = make(chan *Update) - } - go s.start() - - return s } func (s *Subscriber) start() { for { select { - case <-s.ClientDisconnect: + case <-s.disconnected: return - case <-s.ServerDisconnect: - return - case u, ok := <-s.HistorySrc.In: + case u, ok := <-s.History.In: if !ok { - s.HistorySrc.In = nil + s.History.In = nil break } if s.CanDispatch(u) { - s.HistorySrc.buffer = append(s.HistorySrc.buffer, u) + s.History.buffer = append(s.History.buffer, u) } - case u := <-s.LiveSrc.In: + case u := <-s.Live.In: if s.CanDispatch(u) { - s.LiveSrc.buffer = append(s.LiveSrc.buffer, u) + s.Live.buffer = append(s.Live.buffer, u) } case s.outChan() <- s.nextUpdate(): - if len(s.HistorySrc.buffer) > 0 { - s.HistorySrc.buffer = s.HistorySrc.buffer[1:] + if len(s.History.buffer) > 0 { + s.History.buffer = s.History.buffer[1:] break } - s.LiveSrc.buffer = s.LiveSrc.buffer[1:] + s.Live.buffer = s.Live.buffer[1:] } } } func (s *Subscriber) outChan() chan *Update { - if len(s.LiveSrc.buffer) > 0 || len(s.HistorySrc.buffer) > 0 { + if len(s.Live.buffer) > 0 || len(s.History.buffer) > 0 { return s.Out } return nil @@ -100,15 +85,15 @@ func (s *Subscriber) outChan() chan *Update { func (s *Subscriber) nextUpdate() *Update { // Always flush the history buffer first to preserve order - if s.HistorySrc.In != nil || len(s.HistorySrc.buffer) > 0 { - if len(s.HistorySrc.buffer) > 0 { - return s.HistorySrc.buffer[0] + if s.History.In != nil || len(s.History.buffer) > 0 { + if len(s.History.buffer) > 0 { + return s.History.buffer[0] } return nil } - if len(s.LiveSrc.buffer) > 0 { - return s.LiveSrc.buffer[0] + if len(s.Live.buffer) > 0 { + return s.Live.buffer[0] } return nil @@ -117,12 +102,12 @@ func (s *Subscriber) nextUpdate() *Update { // CanDispatch checks if an update can be dispatched to this subsriber. func (s *Subscriber) CanDispatch(u *Update) bool { if !s.IsAuthorized(u) { - log.WithFields(createBaseLogFields(s.debug, s.RemoteAddr, u, s)).Debug("Subscriber not authorized to receive this update (no targets matching)") + log.WithFields(createFields(u, s)).Debug("Subscriber not authorized to receive this update (no targets matching)") return false } if !s.IsSubscribed(u) { - log.WithFields(createBaseLogFields(s.debug, s.RemoteAddr, u, s)).Debug("Subscriber has not subscribed to this update (no topics matching)") + log.WithFields(createFields(u, s)).Debug("Subscriber has not subscribed to this update (no topics matching)") return false } @@ -132,11 +117,11 @@ func (s *Subscriber) CanDispatch(u *Update) bool { // IsAuthorized checks if the subscriber can access to at least one of the update's intended targets. // Don't forget to also call IsSubscribed. func (s *Subscriber) IsAuthorized(u *Update) bool { - if s.AllTargets || len(u.Targets) == 0 { + if s.allTargets || len(u.Targets) == 0 { return true } - for t := range s.Targets { + for t := range s.targets { if _, ok := u.Targets[t]; ok { return true } @@ -156,14 +141,14 @@ func (s *Subscriber) IsSubscribed(u *Update) bool { continue } - for _, rt := range s.RawTopics { + for _, rt := range s.rawTopics { if ut == rt { s.matchCache[ut] = true return true } } - for _, tt := range s.TemplateTopics { + for _, tt := range s.templateTopics { if tt.Match(ut) != nil { s.matchCache[ut] = true return true @@ -180,18 +165,26 @@ func (s *Subscriber) IsSubscribed(u *Update) bool { func (s *Subscriber) Dispatch(u *Update, fromHistory bool) bool { var in chan<- *Update if fromHistory { - in = s.HistorySrc.In + in = s.History.In } else { - in = s.LiveSrc.In + in = s.Live.In } select { - case <-s.ServerDisconnect: - return false - case <-s.ClientDisconnect: + case <-s.disconnected: return false case in <- u: } return true } + +func (s *Subscriber) Disconnect() { + select { + case <-s.disconnected: + return + default: + } + + close(s.disconnected) +} diff --git a/hub/subscriber_test.go b/hub/subscriber_test.go index 917af2dc..e2fdfb4c 100644 --- a/hub/subscriber_test.go +++ b/hub/subscriber_test.go @@ -7,7 +7,9 @@ import ( ) func TestIsSubscribed(t *testing.T) { - s := NewSubscriber(false, nil, []string{"foo", "bar"}, []string{"foo", "bar"}, nil, "lid") + s := newSubscriber() + s.topics = []string{"foo", "bar"} + s.rawTopics = s.topics assert.Len(t, s.matchCache, 0) assert.False(t, s.IsSubscribed(&Update{Topics: []string{"baz", "bat"}})) diff --git a/hub/transport.go b/hub/transport.go index 1b283e8c..a6afee07 100644 --- a/hub/transport.go +++ b/hub/transport.go @@ -108,7 +108,7 @@ func (t *LocalTransport) Close() error { t.RLock() defer t.RUnlock() for subscriber := range t.subscribers { - close(subscriber.ServerDisconnect) + subscriber.Disconnect() delete(t.subscribers, subscriber) } close(t.done) diff --git a/hub/transport_test.go b/hub/transport_test.go index a5c2c38c..1d608371 100644 --- a/hub/transport_test.go +++ b/hub/transport_test.go @@ -1,157 +1,151 @@ package hub import ( - "context" "os" "sync" "testing" - "time" "github.com/spf13/viper" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestLocalTransportWriteIsNotDispatchedUntilListen(t *testing.T) { - transport := NewLocalTransport(5, time.Second) +func TestLocalTransportDoNotDispatchUntilListen(t *testing.T) { + transport := NewLocalTransport() defer transport.Close() assert.Implements(t, (*Transport)(nil), transport) - err := transport.Write(&Update{}) - assert.Nil(t, err) + u := &Update{ + Topics: []string{"http://example.com/books/1"}, + } + err := transport.Dispatch(u) + require.Nil(t, err) - pipe, err := transport.CreatePipe("") - assert.Nil(t, err) - require.NotNil(t, pipe) + s := newSubscriber() + s.topics = u.Topics + s.rawTopics = u.Topics + s.targets = map[string]struct{}{"foo": {}} + go s.start() + + err = transport.AddSubscriber(s) + require.Nil(t, err) var ( + wg sync.WaitGroup readUpdate *Update ok bool - m sync.Mutex - wg sync.WaitGroup ) wg.Add(1) go func() { - m.Lock() - defer m.Unlock() - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) - defer cancel() - go wg.Done() - + defer wg.Done() select { - case readUpdate, ok = <-pipe.Read(): - case <-ctx.Done(): + case readUpdate = <-s.Out: + case <-s.disconnected: + ok = true } }() - wg.Wait() - pipe.Close() + s.Disconnect() - m.Lock() - defer m.Unlock() + wg.Wait() assert.Nil(t, readUpdate) - assert.False(t, ok) + assert.True(t, ok) } -func TestLocalTransportWriteIsDispatched(t *testing.T) { - transport := NewLocalTransport(5, time.Second) +func TestLocalTransportDispatch(t *testing.T) { + transport := NewLocalTransport() defer transport.Close() assert.Implements(t, (*Transport)(nil), transport) - pipe, err := transport.CreatePipe("") - assert.Nil(t, err) - require.NotNil(t, pipe) - defer pipe.Close() + s := newSubscriber() + s.topics = []string{"http://example.com/foo"} + s.rawTopics = s.topics + go s.start() - var ( - readUpdate *Update - ok bool - m sync.Mutex - wg sync.WaitGroup - ) - wg.Add(1) - go func() { - m.Lock() - defer m.Unlock() + err := transport.AddSubscriber(s) + assert.Nil(t, err) - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) - defer cancel() - go wg.Done() - select { - case readUpdate, ok = <-pipe.Read(): - case <-ctx.Done(): - } - }() + u := &Update{Topics: s.topics} - wg.Wait() - err = transport.Write(&Update{}) + err = transport.Dispatch(u) assert.Nil(t, err) - m.Lock() - defer m.Unlock() - - assert.True(t, ok) - assert.NotNil(t, readUpdate) + readUpdate := <-s.Out + assert.Equal(t, u, readUpdate) } func TestLocalTransportClosed(t *testing.T) { - transport := NewLocalTransport(5, time.Second) + transport := NewLocalTransport() defer transport.Close() assert.Implements(t, (*Transport)(nil), transport) - pipe, _ := transport.CreatePipe("") - require.NotNil(t, pipe) + s := newSubscriber() + err := transport.AddSubscriber(s) + require.Nil(t, err) - err := transport.Close() + err = transport.Close() assert.Nil(t, err) - _, err = transport.CreatePipe("") + err = transport.AddSubscriber(newSubscriber()) assert.Equal(t, err, ErrClosedTransport) - err = transport.Write(&Update{}) + err = transport.Dispatch(&Update{}) assert.Equal(t, err, ErrClosedTransport) - _, ok := <-pipe.Read() + _, ok := <-s.disconnected assert.False(t, ok) } -func TestLiveCleanClosedPipes(t *testing.T) { - transport := NewLocalTransport(5, time.Second) +func TestLiveCleanDisconnectedSubscribers(t *testing.T) { + transport := NewLocalTransport() defer transport.Close() - pipe, _ := transport.CreatePipe("") - require.NotNil(t, pipe) + s1 := newSubscriber() + go s1.start() + + err := transport.AddSubscriber(s1) + require.Nil(t, err) + + s2 := newSubscriber() + go s2.start() + + err = transport.AddSubscriber(s2) + require.Nil(t, err) - assert.Len(t, transport.pipes, 1) + assert.Len(t, transport.subscribers, 2) - pipe.Close() - assert.Len(t, transport.pipes, 1) + s1.Disconnect() + assert.Len(t, transport.subscribers, 2) - transport.Write(&Update{}) - assert.Len(t, transport.pipes, 0) + transport.Dispatch(&Update{Topics: s1.topics}) + assert.Len(t, transport.subscribers, 1) + + s2.Disconnect() + assert.Len(t, transport.subscribers, 1) + + transport.Dispatch(&Update{}) + assert.Len(t, transport.subscribers, 0) } -func TestLivePipeReadingBlocks(t *testing.T) { - transport := NewLocalTransport(5, time.Second) +func TestLiveReading(t *testing.T) { + transport := NewLocalTransport() defer transport.Close() assert.Implements(t, (*Transport)(nil), transport) - pipe, err := transport.CreatePipe("") + s := newSubscriber() + s.topics = []string{"https://example.com"} + s.rawTopics = s.topics + go s.start() + + err := transport.AddSubscriber(s) assert.Nil(t, err) - require.NotNil(t, pipe) - var wg sync.WaitGroup - wg.Add(1) - go func() { - wg.Wait() - err := transport.Write(&Update{}) - assert.Nil(t, err) - }() - wg.Done() - u, ok := <-pipe.Read() - assert.True(t, ok) - assert.NotNil(t, u) + u := &Update{Topics: s.topics} + err = transport.Dispatch(u) + assert.Nil(t, err) + + receivedUpdate := <-s.Out + assert.Equal(t, u, receivedUpdate) } func TestNewTransport(t *testing.T) {