Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stream interface + satisfy consumer interface #18

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,7 @@ import (
"github.com/redis/go-redis/v9"
)

// Consumer is a generic consumer interface
type Consumer[T any] interface {
Chan() <-chan Message[T]
Close()
}
var _ Consumer[any] = (*StreamConsumer[any])(nil)

type StreamIDs = map[string]string

Expand Down Expand Up @@ -85,7 +81,7 @@ func (sc *StreamConsumer[T]) Chan() <-chan Message[T] {
//
// The StreamIds can be used to construct a new StreamConsumer that will
// pick up where this left off.
func (sc *StreamConsumer[T]) Close() StreamIDs {
func (sc *StreamConsumer[T]) Close() any {
select {
case <-sc.ctx.Done():
default:
Expand Down
4 changes: 2 additions & 2 deletions consumer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func TestConsumer_CloseGetSeenIDs(t *testing.T) {
<-cs.Chan()
}

seen := cs.Close()
seen := cs.Close().(StreamIDs)
assert.Equal(t, fmt.Sprintf("0-%v", consumeCount), seen["s1"])
}

Expand Down Expand Up @@ -172,7 +172,7 @@ func TestConsumer_CancelContext(t *testing.T) {
}
}

seen := cs.Close()
seen := cs.Close().(StreamIDs)
assert.Equal(t, fmt.Sprintf("0-%v", consumeCount), seen["s1"])
}

Expand Down
4 changes: 3 additions & 1 deletion group_consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"github.com/redis/go-redis/v9"
)

var _ (Consumer[any]) = (*GroupConsumer[any])(nil)

// ErrAckBadRetVal is caused by XACK not accepting an request by returning 0.
// This usually indicates that the id is wrong or the stream has no groups.
var ErrAckBadRetVal = errors.New("XAck made no acknowledgement")
Expand Down Expand Up @@ -142,7 +144,7 @@ func (gc *GroupConsumer[T]) AwaitAcks() []Message[T] {
// Close closes the consumer (if not already closed) and returns
// a slice of unprocessed ack requests. An ack request in unprocessed if it
// wasn't sent or its error wasn't consumed.
func (gc *GroupConsumer[T]) Close() []InnerAck {
func (gc *GroupConsumer[T]) Close() any {
select {
case <-gc.ctx.Done():
default:
Expand Down
4 changes: 2 additions & 2 deletions group_consumer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ func TestGroupConsumer_AckErrors(t *testing.T) {
}

lastErrs := cs.AwaitAcks()
unseen := cs.Close()
unseen := cs.Close().([]InnerAck)
assert.NotZero(t, len(lastErrs)+len(unseen))
assert.Equal(t, readCount, ackErrors+len(unseen)+len(lastErrs))
}
Expand Down Expand Up @@ -397,5 +397,5 @@ func TestGroupConsumer_ConcurrentRead(t *testing.T) {
assert.Greater(t, len(msgError), 1)
assert.Greater(t, len(msg), 1)
assert.Equal(t, len(msg)+len(msgError), 15)
assert.Equal(t, len(cs.Close())+len(msgList)+len(msg)+len(msgError), 101)
assert.Equal(t, len(cs.Close().([]InnerAck))+len(msgList)+len(msg)+len(msgError), 101)
}
18 changes: 18 additions & 0 deletions interface.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package gtrs

import "context"

// Stream represents a redis stream with messages of type T.
type Stream[T any] interface {
Add(ctx context.Context, v T, idarg ...string) (string, error)
Key() string
Len(ctx context.Context) (int64, error)
Range(ctx context.Context, from, to string, count ...int64) ([]Message[T], error)
RevRange(ctx context.Context, from, to string, count ...int64) ([]Message[T], error)
}

// Consumer is a generic consumer interface
type Consumer[T any] interface {
Chan() <-chan Message[T]
Close() any
}
18 changes: 10 additions & 8 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ var NoMaxLen = int64(0)
// now is defined here so it can be overridden in unit tests
var now = time.Now

var _ Stream[any] = (*RedisStream[any])(nil)

// Stream represents a redis stream with messages of type T.
type Stream[T any] struct {
type RedisStream[T any] struct {
client redis.Cmdable
stream string
ttl time.Duration
Expand All @@ -38,7 +40,7 @@ type Options struct {

// NewStream create a new stream with messages of type T.
// Options are optional (the parameter can be nil to use defaults).
func NewStream[T any](client redis.Cmdable, stream string, opt *Options) Stream[T] {
func NewStream[T any](client redis.Cmdable, stream string, opt *Options) RedisStream[T] {
var approx bool
maxLen := NoMaxLen
ttl := NoExpiration
Expand All @@ -47,16 +49,16 @@ func NewStream[T any](client redis.Cmdable, stream string, opt *Options) Stream[
maxLen = opt.MaxLen
approx = opt.Approx
}
return Stream[T]{client: client, stream: stream, ttl: ttl, maxLen: maxLen, approx: approx}
return RedisStream[T]{client: client, stream: stream, ttl: ttl, maxLen: maxLen, approx: approx}
}

// Key returns the redis stream key.
func (s Stream[T]) Key() string {
func (s RedisStream[T]) Key() string {
return s.stream
}

// Add a message to the stream. Calls XADD.
func (s Stream[T]) Add(ctx context.Context, v T, idarg ...string) (string, error) {
func (s RedisStream[T]) Add(ctx context.Context, v T, idarg ...string) (string, error) {
id := ""
if len(idarg) > 0 {
id = idarg[0]
Expand Down Expand Up @@ -92,7 +94,7 @@ func (s Stream[T]) Add(ctx context.Context, v T, idarg ...string) (string, error
}

// Range returns a portion of the stream. Calls XRANGE.
func (s Stream[T]) Range(ctx context.Context, from, to string, count ...int64) ([]Message[T], error) {
func (s RedisStream[T]) Range(ctx context.Context, from, to string, count ...int64) ([]Message[T], error) {
var redisSlice []redis.XMessage
var err error
if len(count) == 0 {
Expand All @@ -113,7 +115,7 @@ func (s Stream[T]) Range(ctx context.Context, from, to string, count ...int64) (
}

// RevRange returns a portion of the stream in reverse order compared to Range. Calls XREVRANGE.
func (s Stream[T]) RevRange(ctx context.Context, from, to string, count ...int64) ([]Message[T], error) {
func (s RedisStream[T]) RevRange(ctx context.Context, from, to string, count ...int64) ([]Message[T], error) {
var redisSlice []redis.XMessage
var err error
if len(count) == 0 {
Expand All @@ -134,7 +136,7 @@ func (s Stream[T]) RevRange(ctx context.Context, from, to string, count ...int64
}

// Len returns the current stream length. Calls XLEN.
func (s Stream[T]) Len(ctx context.Context) (int64, error) {
func (s RedisStream[T]) Len(ctx context.Context) (int64, error) {
len, err := s.client.XLen(ctx, s.stream).Result()
if err != nil {
err = ReadError{Err: err}
Expand Down