diff --git a/server/clustering_test.go b/server/clustering_test.go
index 43923c0b..5aee5e8b 100644
--- a/server/clustering_test.go
+++ b/server/clustering_test.go
@@ -6627,11 +6627,12 @@ type blockingLookupStore struct {
 }
 
 func (b *blockingLookupStore) Lookup(seq uint64) (*pb.MsgProto, error) {
+	msg, err := b.MsgStore.Lookup(seq)
 	if !b.skip {
 		b.inLookupCh <- struct{}{}
 		b.skip = <-b.releaseCh
 	}
-	return b.MsgStore.Lookup(seq)
+	return msg, err
 }
 
 func TestClusteringRestoreSnapshotErrorDontSkipSeq(t *testing.T) {
diff --git a/server/server.go b/server/server.go
index d17235db..eea59aae 100644
--- a/server/server.go
+++ b/server/server.go
@@ -5169,7 +5169,7 @@ func (s *StanServer) sendAvailableMessages(c *channel, sub *subState) {
 }
 
 func (s *StanServer) getNextMsg(c *channel, nextSeq, lastSent *uint64) *pb.MsgProto {
-	for {
+	for i := 0; ; i++ {
 		nextMsg, err := c.store.Msgs.Lookup(*nextSeq)
 		if err != nil {
 			s.log.Errorf("Error looking up message %v:%v (%v)", c.name, *nextSeq, err)
@@ -5180,13 +5180,33 @@ func (s *StanServer) getNextMsg(c *channel, nextSeq, lastSent *uint64) *pb.MsgPr
 		if nextMsg != nil {
 			return nextMsg
 		}
+		// Message was not found, check the store first/last sequences.
 		first, last, _ := c.store.Msgs.FirstAndLastSequence()
-		if *nextSeq < first {
+		if *nextSeq >= last {
+			// This means that we are looking for a message that has not
+			// been stored. This is perfectly normal when delivering messages
+			// and reach the end of the channel.
+			return nil
+		} else if *nextSeq < first {
+			// We were trying to lookup a message that has likely now
+			// been removed (expired, or due to max msgs/bytes etc) since
+			// the first available is greater than the message we were
+			// looking for. Try to lookup the first available.
 			*nextSeq = first
 			*lastSent = first - 1
-		} else if *nextSeq >= last {
-			return nil
-		} else {
+		} else if i > 0 {
+			// The last condition is when `nextSeq` is between `first` and
+			// `last`, which could mean 2 things:
+			// - New messages have been stored between the lookup at the top
+			// of the loop and calling FirstAndLastSequence(), so we should
+			// try again.
+			// - There is a gap - which should not happen but we have decided
+			// to support this situation - so we move by one at a time until
+			// we find a valid message.
+
+			// So if i==0 (first iteration) we don't come here and simply try
+			// again. Otherwise, move the requested sequence in search of the
+			// first valid message.
 			*nextSeq++
 			*lastSent++
 		}
diff --git a/server/server_delivery_test.go b/server/server_delivery_test.go
index adb2fc81..50f19cde 100644
--- a/server/server_delivery_test.go
+++ b/server/server_delivery_test.go
@@ -15,6 +15,7 @@ package server
 
 import (
 	"fmt"
+	"sync"
 	"sync/atomic"
 	"testing"
 	"time"
@@ -324,3 +325,60 @@ func TestPersistentStoreSQLSubsPendingRows(t *testing.T) {
 	}
 	waitForAcks(t, s, clientName, 1, 3002)
 }
+
+func TestDeliveryRaceBetweenNextMsgAndStoring(t *testing.T) {
+	s := runServer(t, clusterName)
+	defer s.Shutdown()
+
+	sc := NewDefaultConnection(t)
+	defer sc.Close()
+
+	prev := uint64(0)
+	errCh := make(chan error, 1)
+	doneCh := make(chan bool)
+	cb := func(m *stan.Msg) {
+		if m.Sequence != prev+1 {
+			errCh <- fmt.Errorf("Previous was %v, now got %v", prev, m.Sequence)
+			m.Sub.Close()
+			return
+		}
+		prev = m.Sequence
+		if m.Sequence == 4 {
+			doneCh <- true
+		}
+	}
+	if _, err := sc.Subscribe("foo", cb, stan.MaxInflight(1)); err != nil {
+		t.Fatalf("Erro on subscribe: %v", err)
+	}
+
+	sc.Publish("foo", []byte("msg1"))
+
+	c := s.channels.get("foo")
+	ch1 := make(chan struct{})
+	ch2 := make(chan bool)
+	c.store.Msgs = &blockingLookupStore{MsgStore: c.store.Msgs, inLookupCh: ch1, releaseCh: ch2}
+
+	sub := s.clients.getSubs(clientName)[0]
+	wg := sync.WaitGroup{}
+	wg.Add(1)
+	go func() {
+		s.sendAvailableMessages(c, sub)
+		wg.Done()
+	}()
+	<-ch1
+	sc.PublishAsync("foo", []byte("msg2"), nil)
+	sc.PublishAsync("foo", []byte("msg3"), nil)
+	time.Sleep(50 * time.Millisecond)
+	ch2 <- true
+	wg.Wait()
+
+	sc.Publish("foo", []byte("msg4"))
+
+	select {
+	case <-doneCh:
+	case e := <-errCh:
+		t.Fatal(e.Error())
+	case <-time.After(time.Second):
+		t.Fatal("Timeout!")
+	}
+}