diff --git a/stores/common.go b/stores/common.go index 616f4f48..c460bb03 100644 --- a/stores/common.go +++ b/stores/common.go @@ -15,6 +15,7 @@ package stores import ( "sync" + "time" "github.com/nats-io/nats-streaming-server/logger" "github.com/nats-io/nats-streaming-server/spb" @@ -337,6 +338,17 @@ func (gms *genericMsgStore) Close() error { return nil } +// With the given timestamp, returns in how long the message +// should expire. If in the past, returns 0 +func (gms *genericMsgStore) msgExpireIn(timestamp int64) time.Duration { + now := time.Now().UnixNano() + fireIn := time.Duration(timestamp + int64(gms.limits.MaxAge) - now) + if fireIn < 0 { + fireIn = 0 + } + return fireIn +} + //////////////////////////////////////////////////////////////////////////// // genericSubStore methods //////////////////////////////////////////////////////////////////////////// diff --git a/stores/common_msg_test.go b/stores/common_msg_test.go index 05d1fdc5..636ef8b3 100644 --- a/stores/common_msg_test.go +++ b/stores/common_msg_test.go @@ -495,6 +495,42 @@ func TestCSMaxAgeWithGapInSeq(t *testing.T) { } } +func TestCSMaxAgeForMsgsWithTimestampInPast(t *testing.T) { + for _, st := range testStores { + st := st + t.Run(st.name, func(t *testing.T) { + t.Parallel() + defer endTest(t, st) + s := startTest(t, st) + defer s.Close() + + sl := testDefaultStoreLimits + sl.MaxAge = time.Minute + s.SetLimits(&sl) + + cs := storeCreateChannel(t, s, "foo") + for seq := uint64(1); seq < 3; seq++ { + // Create a message with a timestamp in the past. + msg := &pb.MsgProto{ + Sequence: seq, + Subject: "foo", + Data: []byte("hello"), + Timestamp: time.Now().Add(-time.Hour).UnixNano(), + } + if _, err := cs.Msgs.Store(msg); err != nil { + t.Fatalf("Error storing message: %v", err) + } + // Wait a bit + time.Sleep(300 * time.Millisecond) + // Check that message has expired. + if first, err := cs.Msgs.FirstSequence(); err != nil || first != seq+1 { + t.Fatal("Message should have expired") + } + } + }) + } +} + func TestCSGetSeqFromStartTime(t *testing.T) { for _, st := range testStores { st := st diff --git a/stores/memstore.go b/stores/memstore.go index 650a9926..51bda5be 100644 --- a/stores/memstore.go +++ b/stores/memstore.go @@ -107,7 +107,7 @@ func (ms *MemoryMsgStore) Store(m *pb.MsgProto) (uint64, error) { // If there is an age limit and no timer yet created, do so now if ms.limits.MaxAge > time.Duration(0) && ms.ageTimer == nil { ms.wg.Add(1) - ms.ageTimer = time.AfterFunc(ms.limits.MaxAge, ms.expireMsgs) + ms.ageTimer = time.AfterFunc(ms.msgExpireIn(m.Timestamp), ms.expireMsgs) } // Check if we need to remove any (but leave at least the last added) diff --git a/stores/sqlstore.go b/stores/sqlstore.go index 24a1081b..809d748c 100644 --- a/stores/sqlstore.go +++ b/stores/sqlstore.go @@ -331,6 +331,7 @@ type SQLMsgStore struct { channelID int64 sqlStore *SQLStore // Reference to "parent" store expireTimer *time.Timer + fTimestamp int64 wg sync.WaitGroup // If option NoBuffering is false, uses this cache for storing Store() @@ -1407,8 +1408,9 @@ func (ms *SQLMsgStore) Store(m *pb.MsgProto) (uint64, error) { return 0, sqlStmtError(sqlStoreMsg, err) } } - if ms.first == 0 { + if ms.first == 0 || ms.first == seq { ms.first = seq + ms.fTimestamp = m.Timestamp } ms.last = seq ms.totalCount++ @@ -1461,7 +1463,7 @@ func (ms *SQLMsgStore) Store(m *pb.MsgProto) (uint64, error) { func (ms *SQLMsgStore) createExpireTimer() { ms.wg.Add(1) - ms.expireTimer = time.AfterFunc(ms.limits.MaxAge, ms.expireMsgs) + ms.expireTimer = time.AfterFunc(ms.msgExpireIn(ms.fTimestamp), ms.expireMsgs) } // Lookup implements the MsgStore interface @@ -1563,7 +1565,6 @@ func (ms *SQLMsgStore) expireMsgs() { count int maxSeq uint64 totalSize uint64 - timestamp int64 ) processErr := func(errCode int, err error) { ms.log.Errorf("Unable to perform expiration for channel %q: %v", ms.subject, sqlStmtError(errCode, err)) @@ -1595,24 +1596,24 @@ func (ms *SQLMsgStore) expireMsgs() { ms.totalBytes -= totalSize } // Reset since we are in a loop - timestamp = 0 + ms.fTimestamp = 0 // If there is any message left in the channel, find out what the expiration // timer needs to be set to. if ms.totalCount > 0 { r = ms.sqlStore.preparedStmts[sqlGetFirstMsgTimestamp].QueryRow(ms.channelID, ms.first) - if err := r.Scan(×tamp); err != nil { + if err := r.Scan(&ms.fTimestamp); err != nil { processErr(sqlGetFirstMsgTimestamp, err) return } } // No message left or no message to expire. The timer will be recreated when // a new message is added to the channel. - if timestamp == 0 { - ms.wg.Done() + if ms.fTimestamp == 0 { ms.expireTimer = nil + ms.wg.Done() return } - elapsed := time.Duration(time.Now().UnixNano() - timestamp) + elapsed := time.Duration(time.Now().UnixNano() - ms.fTimestamp) if elapsed < ms.limits.MaxAge { ms.expireTimer.Reset(ms.limits.MaxAge - elapsed) // Done with the for loop diff --git a/stores/sqlstore_test.go b/stores/sqlstore_test.go index 7ea2850b..214301db 100644 --- a/stores/sqlstore_test.go +++ b/stores/sqlstore_test.go @@ -2095,3 +2095,48 @@ func TestSQLDeadlines(t *testing.T) { return err }) } + +func TestSQLMaxAgeForMsgsWithTimestampInPast(t *testing.T) { + if !doSQL { + t.SkipNow() + } + + cleanupSQLDatastore(t) + defer cleanupSQLDatastore(t) + + // Create store with caching enabled (the no cache is handled in + // test TestCSMaxAgeForMsgsWithTimestampInPast). + s, err := NewSQLStore(testLogger, testSQLDriver, testSQLSource, nil, SQLNoCaching(false)) + if err != nil { + t.Fatalf("Error creating store: %v", err) + } + defer s.Close() + + sl := testDefaultStoreLimits + sl.MaxAge = time.Minute + s.SetLimits(&sl) + + cs := storeCreateChannel(t, s, "foo") + for seq := uint64(1); seq < 3; seq++ { + // Create a message with a timestamp in the past. + msg := &pb.MsgProto{ + Sequence: seq, + Subject: "foo", + Data: []byte("hello"), + Timestamp: time.Now().Add(-time.Hour).UnixNano(), + } + if _, err := cs.Msgs.Store(msg); err != nil { + t.Fatalf("Error storing message: %v", err) + } + // With caching, timer is triggered on Flush(). + if err := cs.Msgs.Flush(); err != nil { + t.Fatalf("Error on flush: %v", err) + } + // Wait a bit + time.Sleep(300 * time.Millisecond) + // Check that message has expired. + if first, err := cs.Msgs.FirstSequence(); err != nil || first != seq+1 { + t.Fatal("Message should have expired") + } + } +}