Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix nil pointer dereference on AWS errors #148

Merged
merged 2 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions allgroup.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ type AllGroup struct {

// Start is a blocking operation which will loop and attempt to find new
// shards on a regular cadence.
func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) {
func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) error {
// Note: while ticker is a rather naive approach to this problem,
// it actually simplifies a few things. i.e. If we miss a new shard
// while AWS is resharding we'll pick it up max 30 seconds later.
Expand All @@ -48,12 +48,16 @@ func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) {
var ticker = time.NewTicker(30 * time.Second)

for {
g.findNewShards(ctx, shardc)
err := g.findNewShards(ctx, shardc)
if err != nil {
ticker.Stop()
return err
}

select {
case <-ctx.Done():
ticker.Stop()
return
return nil
case <-ticker.C:
}
}
Expand All @@ -62,7 +66,7 @@ func (g *AllGroup) Start(ctx context.Context, shardc chan types.Shard) {
// findNewShards pulls the list of shards from the Kinesis API
// and uses a local cache to determine if we are already processing
// a particular shard.
func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) {
func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) error {
g.shardMu.Lock()
defer g.shardMu.Unlock()

Expand All @@ -71,7 +75,7 @@ func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) {
shards, err := listShards(ctx, g.ksis, g.streamName)
if err != nil {
g.logger.Log("[GROUP] error:", err)
return
return err
}

for _, shard := range shards {
Expand All @@ -81,4 +85,5 @@ func (g *AllGroup) findNewShards(ctx context.Context, shardc chan types.Shard) {
g.shards[*shard.ShardId] = shard
shardc <- shard
}
return nil
}
11 changes: 9 additions & 2 deletions consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,11 @@ func (c *Consumer) Scan(ctx context.Context, fn ScanFunc) error {
)

go func() {
c.group.Start(ctx, shardc)
err := c.group.Start(ctx, shardc)
if err != nil {
errc <- fmt.Errorf("error starting scan: %w", err)
cancel()
}
<-ctx.Done()
close(shardc)
}()
Expand Down Expand Up @@ -276,7 +280,10 @@ func (c *Consumer) getShardIterator(ctx context.Context, streamName, shardID, se
}

res, err := c.client.GetShardIterator(ctx, params)
return res.ShardIterator, err
if err != nil {
return nil, err
}
return res.ShardIterator, nil
}

func isRetriableError(err error) bool {
Expand Down
67 changes: 67 additions & 0 deletions consumer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package consumer

import (
"context"
"errors"
"fmt"
"sync"
"testing"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/kinesis"
Expand Down Expand Up @@ -98,6 +100,71 @@ func TestScan(t *testing.T) {
}
}

func TestScan_ListShardsError(t *testing.T) {
mockError := errors.New("mock list shards error")
client := &kinesisClientMock{
listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) {
return nil, mockError
},
}

// use cancel func to signal shutdown
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)

var res string
var fn = func(r *Record) error {
res += string(r.Data)
cancel() // simulate cancellation while processing first record
return nil
}

c, err := New("myStreamName", WithClient(client))
if err != nil {
t.Fatalf("new consumer error: %v", err)
}

err = c.Scan(ctx, fn)
if !errors.Is(err, mockError) {
t.Errorf("expected an error from listShards, but instead got %v", err)
}
}

func TestScan_GetShardIteratorError(t *testing.T) {
mockError := errors.New("mock get shard iterator error")
client := &kinesisClientMock{
listShardsMock: func(ctx context.Context, params *kinesis.ListShardsInput, optFns ...func(*kinesis.Options)) (*kinesis.ListShardsOutput, error) {
return &kinesis.ListShardsOutput{
Shards: []types.Shard{
{ShardId: aws.String("myShard")},
},
}, nil
},
getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
return nil, mockError
},
}

// use cancel func to signal shutdown
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)

var res string
var fn = func(r *Record) error {
res += string(r.Data)
cancel() // simulate cancellation while processing first record
return nil
}

c, err := New("myStreamName", WithClient(client))
if err != nil {
t.Fatalf("new consumer error: %v", err)
}

err = c.Scan(ctx, fn)
if !errors.Is(err, mockError) {
t.Errorf("expected an error from getShardIterator, but instead got %v", err)
}
}

func TestScanShard(t *testing.T) {
var client = &kinesisClientMock{
getShardIteratorMock: func(ctx context.Context, params *kinesis.GetShardIteratorInput, optFns ...func(*kinesis.Options)) (*kinesis.GetShardIteratorOutput, error) {
Expand Down
2 changes: 1 addition & 1 deletion group.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (

// Group interface used to manage which shard to process
type Group interface {
Start(ctx context.Context, shardc chan types.Shard)
Start(ctx context.Context, shardc chan types.Shard) error
GetCheckpoint(streamName, shardID string) (string, error)
SetCheckpoint(streamName, shardID, sequenceNumber string) error
}