Skip to content

Commit

Permalink
Fix #174 and Migrate to AWS sdk v2
Browse files Browse the repository at this point in the history
  • Loading branch information
lucagrulla committed Mar 5, 2021
1 parent 021fd56 commit 53e878d
Show file tree
Hide file tree
Showing 8 changed files with 402 additions and 329 deletions.
63 changes: 30 additions & 33 deletions cloudwatch/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,18 @@
package cloudwatch

import (
"context"
"fmt"
"log"
"os"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/cloudwatchlogs"
"github.com/aws/aws-sdk-go/service/cloudwatchlogs/cloudwatchlogsiface"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs"
)

type cwl interface {
Tail(cwl cloudwatchlogsiface.CloudWatchLogsAPI,
logGroupName *string, logStreamName *string, follow *bool, retry *bool,
startTime *time.Time, endTime *time.Time,
grep *string, grepv *string,
limiter <-chan time.Time, log *log.Logger) <-chan *cloudwatchlogs.FilteredLogEvent
LsStreams(cwl cloudwatchlogsiface.CloudWatchLogsAPI, groupName *string, streamName *string) <-chan *string
}

// New creates a new instance of the cloudwatchlogs client
func New(awsEndpointURL *string, awsProfile *string, awsRegion *string, log *log.Logger) *cloudwatchlogs.CloudWatchLogs {
func New(awsEndpointURL *string, awsProfile *string, awsRegion *string, log *log.Logger) *cloudwatchlogs.Client {
//workaround to figure out the user actual home dir within a SNAP (rather than the sandboxed one)
//and access the .aws folder in its default location
if os.Getenv("SNAP_INSTANCE_NAME") != "" {
Expand All @@ -40,29 +30,36 @@ func New(awsEndpointURL *string, awsProfile *string, awsRegion *string, log *log
os.Setenv("AWS_CONFIG_FILE", configPath)
}
}
log.Printf("awsProfile: %s, awsRegion: %s\n", *awsProfile, *awsRegion)

if awsEndpointURL != nil {
log.Printf("awsEndpointURL:%s", *awsEndpointURL)
profile := ""
region := ""
if awsProfile != nil && *awsProfile != "" {
profile = *awsProfile
}
opts := session.Options{
SharedConfigState: session.SharedConfigEnable,
if awsRegion != nil && *awsRegion != "" {
region = *awsRegion
}

if awsProfile != nil {
opts.Profile = *awsProfile
}
log.Printf("awsProfile: %s, awsRegion: %s endpoint: %s\n", profile, region, *awsEndpointURL)

cfg := aws.Config{}
customResolver := aws.EndpointResolverFunc(func(service, region string) (aws.Endpoint, error) {
if awsEndpointURL != nil && *awsEndpointURL != "" {
log.Printf("awsEndpointURL:%s", *awsEndpointURL)
return aws.Endpoint{
PartitionID: "aws",
URL: *awsEndpointURL,
SigningRegion: region,
SigningName: "logs",
}, nil
}
// returning EndpointNotFoundError will allow the service to fallback to it's default resolution
return aws.Endpoint{}, &aws.EndpointNotFoundError{}
})

if awsEndpointURL != nil {
cfg.Endpoint = awsEndpointURL
}
if awsRegion != nil {
cfg.Region = awsRegion
cfg, err := config.LoadDefaultConfig(context.TODO(), config.WithSharedConfigProfile(profile),
config.WithEndpointResolver(customResolver), config.WithRegion(region))
if err != nil {
os.Exit(1)
}

opts.Config = cfg
sess := session.Must(session.NewSessionWithOptions(opts))
return cloudwatchlogs.New(sess)
return cloudwatchlogs.NewFromConfig(cfg)
}
192 changes: 116 additions & 76 deletions cloudwatch/cloudwatchlogs_test.go
Original file line number Diff line number Diff line change
@@ -1,115 +1,155 @@
package cloudwatch

import (
"io/ioutil"
"context"
"fmt"
"log"
"os"
"testing"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/service/cloudwatchlogs"
"github.com/aws/aws-sdk-go/service/cloudwatchlogs/cloudwatchlogsiface"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs"
"github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs/types"
"github.com/stretchr/testify/assert"
)

type mockCloudWatchLogsClient struct {
cloudwatchlogsiface.CloudWatchLogsAPI
streams []string
}

type mockCloudWatchLogsClientRetry struct {
cloudwatchlogsiface.CloudWatchLogsAPI
streams []string
}

var (
streams = []string{"a", "b"}
logger = log.New(ioutil.Discard, "", log.LstdFlags)
streams = []types.LogStream{
{LogStreamName: aws.String("stream1"), LastIngestionTime: aws.Int64(time.Now().Unix())},
{LogStreamName: aws.String("stream2"), LastIngestionTime: aws.Int64(time.Now().AddDate(1, 0, 0).Unix())}}
)

func (m *mockCloudWatchLogsClient) DescribeLogStreamsPages(input *cloudwatchlogs.DescribeLogStreamsInput,
fn func(*cloudwatchlogs.DescribeLogStreamsOutput, bool) bool) error {
// s := []*cloudwatchlogs.LogStream{}
// for _, t := range m.streams {
// s = append(s, &cloudwatchlogs.LogStream{LogStreamName: aws.String(t)})
// }
// o := &cloudwatchlogs.DescribeLogStreamsOutput{LogStreams: s}
// fn(o, true)
return awserr.New("ResourceNotFoundException", "", nil)
type MockPager struct {
PageNum int
Pages []*cloudwatchlogs.DescribeLogStreamsOutput
err error
}

type mockCloudWatchLogsClientLsStreams struct {
cloudwatchlogsiface.CloudWatchLogsAPI
streams []string
func (m *MockPager) HasMorePages() bool {
return m.PageNum < len(m.Pages)
}

func (m *mockCloudWatchLogsClientLsStreams) DescribeLogStreamsPages(input *cloudwatchlogs.DescribeLogStreamsInput,
fn func(*cloudwatchlogs.DescribeLogStreamsOutput, bool) bool) error {
s := []*cloudwatchlogs.LogStream{}
for _, t := range m.streams {
s = append(s, &cloudwatchlogs.LogStream{LogStreamName: aws.String(t)})
func (m *MockPager) NextPage(ctx context.Context, optFns ...func(*cloudwatchlogs.Options)) (*cloudwatchlogs.DescribeLogStreamsOutput, error) {
if m.err != nil {
return nil, m.err
}
if m.PageNum >= len(m.Pages) {
return nil, fmt.Errorf("no more pages")
}
o := &cloudwatchlogs.DescribeLogStreamsOutput{LogStreams: s}
fn(o, true)
return nil
output := m.Pages[m.PageNum]
m.PageNum++
return output, nil
}

func TestLsStreams(t *testing.T) {
mockSvc := &mockCloudWatchLogsClientLsStreams{
streams: streams,
pag := &MockPager{PageNum: 0,
Pages: []*cloudwatchlogs.DescribeLogStreamsOutput{{LogStreams: streams}},
}
ch, _ := LsStreams(mockSvc, aws.String("a"), aws.String("b"))
ch := make(chan types.LogStream)
errCh := make(chan error)
go getStreams(pag, errCh, ch)

for l := range ch {
assert.Contains(t, streams, *l)
assert.Contains(t, streams, l)
}
}

func TestTailShouldFailIfNoStreamsAdNoRetry(t *testing.T) {
mockSvc := &mockCloudWatchLogsClient{}
mockSvc.streams = []string{}
idleCh := make(chan bool)

n := time.Now()
trigger := time.NewTicker(100 * time.Millisecond).C
fetchStreams := func() (<-chan types.LogStream, <-chan error) {
ch := make(chan types.LogStream)
errCh := make(chan error, 1)
rnf := &types.ResourceNotFoundException{
Message: new(string),
}
errCh <- rnf
return ch, errCh
}
retry := false
err := initialiseStreams(&retry, idleCh, nil, fetchStreams)

ch, e := Tail(mockSvc, aws.String("logGroup"), aws.String("logStreamName"), aws.Bool(false), aws.Bool(false),
&n, &n, aws.String(""), aws.String(""),
trigger, logger)
assert.Error(t, e)
assert.Nil(t, ch)
assert.Error(t, err)
}

var cnt = 0

func (m *mockCloudWatchLogsClientRetry) DescribeLogStreamsPages(input *cloudwatchlogs.DescribeLogStreamsInput,
fn func(*cloudwatchlogs.DescribeLogStreamsOutput, bool) bool) error {
s := []*cloudwatchlogs.LogStream{}
if cnt != 0 {
for _, t := range m.streams {
s = append(s, &cloudwatchlogs.LogStream{LogStreamName: aws.String(t)})
func TestTailWaitForStreamsWithRetry(t *testing.T) {
log.SetOutput(os.Stderr)
idleCh := make(chan bool, 1)

callsToFetchStreams := 0
fetchStreams := func() (<-chan types.LogStream, <-chan error) {
callsToFetchStreams++
ch := make(chan types.LogStream, 5)
errCh := make(chan error, 1)

if callsToFetchStreams == 2 {
for _, s := range streams {
ch <- s
}
close(ch)
} else {
rnf := &types.ResourceNotFoundException{
Message: new(string),
}
errCh <- rnf
}
return ch, errCh
}
cnt++

fn(&cloudwatchlogs.DescribeLogStreamsOutput{LogStreams: s}, true)
return nil
retry := true
logStreams := &logStreamsType{}
err := initialiseStreams(&retry, idleCh, logStreams, fetchStreams)

assert.Nil(t, err)
assert.Len(t, logStreams.get(), 2)
var streamNames []string
for _, ls := range streams {
streamNames = append(streamNames, *ls.LogStreamName)
}
for _, s := range logStreams.get() {
assert.Contains(t, streamNames, s)
}
}

func (m *mockCloudWatchLogsClientRetry) FilterLogEventsPages(*cloudwatchlogs.FilterLogEventsInput,
func(*cloudwatchlogs.FilterLogEventsOutput, bool) bool) error {
return nil
func TestShortenLogStreamsListIfTooLong(t *testing.T) {

var streams = []types.LogStream{}

size := 105
for i := 0; i < size; i++ {
name := fmt.Sprintf("streams%d", i)
x := &types.LogStream{LogStreamName: aws.String(name)}
streams = append(streams, *x)
}

assert.Len(t, streams, size)
streams = sortLogStreamsByMostRecentEvent(streams)
assert.Len(t, streams, 100)
}
func TestTailWaitForStreamsWithRetry(t *testing.T) {
mockSvc := &mockCloudWatchLogsClientRetry{
streams: streams,

func TestSortLogStreamsByMostRecentEvent(t *testing.T) {

var streams = []types.LogStream{}

size := 105
for i := 0; i < size; i++ {
t := aws.Int64(time.Now().AddDate(0, 0, -i).Unix())
name := fmt.Sprintf("stream%d", i)
x := &types.LogStream{LogStreamName: aws.String(name), LastIngestionTime: t}
streams = append(streams, *x)
}

n := time.Now()
trigger := time.NewTicker(100 * time.Millisecond).C
first := streams[0]
last := streams[size-1]
assert.Greater(t, *first.LastIngestionTime, *last.LastIngestionTime)
streams = sortLogStreamsByMostRecentEvent(streams)

// eventTimestamp := *s.LastEventTimestamp / 1000
// ts := time.Unix(eventTimestamp, 0).Format(timeFormat)

ch, e := Tail(mockSvc, aws.String("logGroup"), aws.String("logStreamName"), aws.Bool(false), aws.Bool(true),
&n, &n, aws.String(""), aws.String(""),
trigger, logger)
assert.NoError(t, e)
// fmt.Println(ch)
assert.NotNil(t, ch)
assert.Len(t, streams, 100)
assert.Equal(t, *streams[len(streams)-1].LogStreamName, "stream0")
first = streams[0]
last = streams[len(streams)-1]
assert.Less(t, *first.LastIngestionTime, *last.LastIngestionTime)
}
30 changes: 12 additions & 18 deletions cloudwatch/lsgroups.go
Original file line number Diff line number Diff line change
@@ -1,39 +1,33 @@
package cloudwatch

import (
"context"
"fmt"
"os"

"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/service/cloudwatchlogs"
"github.com/aws/aws-sdk-go/service/cloudwatchlogs/cloudwatchlogsiface"
"github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs"
)

//LsGroups lists the stream groups
//It returns a channel where stream groups are published
func LsGroups(cwl cloudwatchlogsiface.CloudWatchLogsAPI) <-chan *string {
func LsGroups(cwc *cloudwatchlogs.Client) <-chan *string {
ch := make(chan *string)
params := &cloudwatchlogs.DescribeLogGroupsInput{}

handler := func(res *cloudwatchlogs.DescribeLogGroupsOutput, lastPage bool) bool {
for _, logGroup := range res.LogGroups {
ch <- logGroup.LogGroupName
}
if lastPage {
close(ch)
}
return !lastPage
}

go func() {
err := cwl.DescribeLogGroupsPages(params, handler)
if err != nil {
if awsErr, ok := err.(awserr.Error); ok {
fmt.Fprintln(os.Stderr, awsErr.Message())
paginator := cloudwatchlogs.NewDescribeLogGroupsPaginator(cwc, params)
for paginator.HasMorePages() {
res, err := paginator.NextPage(context.TODO())
if err != nil {
fmt.Fprintln(os.Stderr, err.Error())
os.Exit(1)
close(ch)
}
for _, logGroup := range res.LogGroups {
ch <- logGroup.LogGroupName
}
}
close(ch)
}()
return ch
}
Loading

0 comments on commit 53e878d

Please sign in to comment.