Skip to content

Commit

Permalink
Adds small updates / fixes to root + tests for root
Browse files Browse the repository at this point in the history
  • Loading branch information
Gowiem committed May 6, 2020
1 parent d2bf185 commit 008e208
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 26 deletions.
68 changes: 47 additions & 21 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ package cmd
import (
"fmt"
"os"
"testing"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
Expand All @@ -46,10 +47,12 @@ var rootCmd = &cobra.Command{
Short: "Easily run one-off tasks against a ECS Cluster",
Long: `
ecsrun is a CLI tool that allows users to run one-off administrative tasks
using their existing ECS Cluster and Task Definitions.
using their existing Task Definitions.
TODO: Supply more info here.`,

Run: func(cmd *cobra.Command, args []string) {},
Run: func(cmd *cobra.Command, args []string) {
log.Info("SHRED!")
},
}

// Execute adds all child commands to the root command and sets flags appropriately.
Expand All @@ -62,22 +65,18 @@ func Execute() {
}

func init() {
cobra.OnInitialize(initConfig)
cobra.OnInitialize(initConfig, initVerbose, initAws)

log.SetOutput(os.Stderr)

rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.escrun.yaml)")
rootCmd.PersistentFlags().BoolP("verbose", "v", false, "verbose output")
rootCmd.PersistentFlags().StringP("cred", "c", "", "aws credentials file (default is $HOME/.aws/.credentials)")
rootCmd.PersistentFlags().StringP("profile", "p", "", "aws profile to target (default is AWS_PROFILE or 'default')")
rootCmd.PersistentFlags().StringP("region", "r", "", `aws region to target (default is AWS_REGION or pulled from $HOME/.aws/.credentials)`)

viper.BindPFlag("profile", rootCmd.PersistentFlags().Lookup("profile"))
viper.BindPFlag("region", rootCmd.PersistentFlags().Lookup("region"))

cred, profile, region := initAws()

viper.Set("profile", profile)
viper.Set("region", region)
viper.Set("accesskey", cred.AccessKeyID)
viper.Set("secretkey", cred.SecretAccessKey)
}

// initConfig reads in config file and ENV variables if set.
Expand Down Expand Up @@ -106,8 +105,22 @@ func initConfig() {
}
}

func initAws() (credentials.Value, string, string) {
profile := getProfile()
func initVerbose() {
verbose, err := rootCmd.PersistentFlags().GetBool("verbose")
if err != nil {
log.Fatal("Unable to pull verbose flag.")
log.Fatal(err)
os.Exit(1)
}

if verbose {
log.Info("Enabling verbose output.")
log.SetLevel(logrus.DebugLevel)
}
}

func initAws() {
profile := getProfile(nil)

// Create our AWS session object for AWS API Usage
sesh, err := initAwsSession(profile)
Expand All @@ -126,42 +139,55 @@ func initAws() (credentials.Value, string, string) {

region := viper.GetString("region")
if region == "" {
region = *awsSession.Config.Region
region = *sesh.Config.Region
}
// Override our Session's region in case it was set.
sesh.Config.WithRegion(region)

return cred, profile, region
viper.Set("profile", profile)
viper.Set("region", region)
viper.Set("accesskey", cred.AccessKeyID)
viper.Set("secretkey", cred.SecretAccessKey)
}

func getProfile() string {
profile := viper.GetString("profile")
func getProfile(t *testing.T) string {
var profile = viper.GetString("profile")
if profile == "" {
profile = "default"
if os.Getenv("AWS_PROFILE") != "" {
profile = os.Getenv("AWS_PROFILE")
}
}

log.Debug("Using AWS Profile: " + profile)
return profile
}

func initAwsSession(profile string) (*session.Session, error) {
credFile, err := rootCmd.Flags().GetString("cred")
credFile, err := rootCmd.PersistentFlags().GetString("cred")
if err != nil {
log.Fatal("Not able to get credFile from cmd.")
log.Fatal(err)
os.Exit(1)
}

log.Debug("Cred File: " + credFile)

var sesh *session.Session

if credFile != "" {
sesh, err := session.NewSession(&aws.Config{
sesh, err = session.NewSession(&aws.Config{
Credentials: credentials.NewSharedCredentials(credFile, profile),
})
return sesh, err
} else {
sesh, err := session.NewSessionWithOptions(session.Options{
sesh, err = session.NewSessionWithOptions(session.Options{
Profile: profile,
SharedConfigState: session.SharedConfigEnable,
Config: aws.Config{
CredentialsChainVerboseErrors: aws.Bool(true),
},
})
return sesh, err
}

return sesh, err
}
47 changes: 42 additions & 5 deletions cmd/root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,62 @@ import (
"testing"

"github.com/stretchr/testify/assert"

"github.com/spf13/viper"
)

var previous_profile string

func setup() {
previous_profile = os.Getenv("AWS_PROFILE")
os.Unsetenv("AWS_PROFILE")
os.Unsetenv("AWS_ACCESS_KEY_ID")
os.Unsetenv("AWS_SECRET_ACCESS_KEY")
}

func teardown() {
os.Setenv("AWS_PROFILE", previous_profile)
viper.Reset()
}

func TestExecute(t *testing.T) {
setup()
assert := assert.New(t)

os.Setenv("AWS_ACCESS_KEY_ID", "123")
os.Setenv("AWS_SECRET_ACCESS_KEY", "SECRET123")
Execute()
_ = assert

var accessKey = viper.Get("accessKey")
assert.Equal("123", accessKey)

var secretKey = viper.Get("secretKey")
assert.Equal("SECRET123", secretKey)

teardown()
}

func TestInitAws(t *testing.T) {
assert := assert.New(t)
setup()

// TODO
t.Skip("TODO")
assert.Equal(true, true)

teardown()
}

func TestGetProfile(t *testing.T) {
assert := assert.New(t)
setup()

profile = getProfile()
assert.Equal(profile, "default")
var profile1 = getProfile(t)
assert.Equal("default", profile1)

os.Setenv("AWS_PROFILE", "not-default")
profile = getProfile()
assert.Equal(profile, "not-default")
var profile2 = getProfile(t)
assert.Equal("not-default", profile2)

teardown()
}

0 comments on commit 008e208

Please sign in to comment.