Skip to content

Commit

Permalink
Small flag rename, add ssm-session flag validation (#35)
Browse files Browse the repository at this point in the history
* Small flag rename, add ssm-session flag validation

* Add test for validateSessionFlags()
  • Loading branch information
sendqueery committed Aug 13, 2020
1 parent 9d085e6 commit cf13c79
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 12 deletions.
16 changes: 15 additions & 1 deletion cmd/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ func addRunFlags(cmd *cobra.Command) {
cmdutil.AddMaxErrorsFlag(cmd, "0", "Max errors allowed before running on additional targets. Both numbers, such as 10, and percentages, such as 10%, are allowed")
}

func addSessionFlags(cmd *cobra.Command) {
cmdutil.AddTagFlag(cmd)
cmdutil.AddSessionNameFlag(cmd, "ssm-session")
cmdutil.AddLimitFlag(cmd, 10, "Set a limit for the number of instance results returned per profile/region combination.")
}

func getCommandList(cmd *cobra.Command) (commandList []string, err error) {
if commandList, err = cmdutil.GetCommandFlagStringSlice(cmd); err != nil {
return nil, err
Expand Down Expand Up @@ -62,7 +68,7 @@ func getRegionList(cmd *cobra.Command) (regionList []string, err error) {
return regionList, nil
}

func getFilterList(cmd *cobra.Command) (targets []*ssm.Target, err error) {
func getTargetList(cmd *cobra.Command) (targets []*ssm.Target, err error) {
var filterList []string
if filterList, err = cmdutil.GetFlagStringSlice(cmd, "filter"); err != nil {
return nil, err
Expand Down Expand Up @@ -130,6 +136,14 @@ Pattern: %q`, 1, 7, "^([1-9][0-9]*|[0]|[1-9][0-9]%|[0-9]%|100%)$")
return maxErrors, nil
}

func validateSessionFlags(cmd *cobra.Command, instanceList []string, filterList map[string]string) error {
if len(instanceList) > 0 && len(filterList) > 0 {
return cmdutil.UsageError(cmd, "The --filter and --instance flags cannot be used simultaneously.")
}

return nil
}

// validateRunFlags validates the usage of certain flags required by the run subcommand
func validateRunFlags(cmd *cobra.Command, instanceList []string, commandList []string, filterList []*ssm.Target) error {
if len(instanceList) > 0 && len(filterList) > 0 {
Expand Down
38 changes: 31 additions & 7 deletions cmd/flags_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,15 +183,15 @@ func Test_getRegionList(t *testing.T) {
})
}

func Test_getFilterList(t *testing.T) {
func Test_getTargetList(t *testing.T) {
assert := assert.New(t)
cmd := NewTestCmd()

t.Run("filter flag undefined", func(t *testing.T) {
cmd.Execute()

filterList, err := getFilterList(cmd)
assert.Len(filterList, 0)
targetList, err := getTargetList(cmd)
assert.Len(targetList, 0)
assert.Error(err)

cmd.ResetFlags()
Expand All @@ -202,8 +202,8 @@ func Test_getFilterList(t *testing.T) {
cmd.SetArgs([]string{"-f", "foo=bar"})
cmd.Execute()

filterList, err := getFilterList(cmd)
assert.Len(filterList, 1)
targetList, err := getTargetList(cmd)
assert.Len(targetList, 1)
assert.NoError(err)

cmd.ResetFlags()
Expand All @@ -214,8 +214,8 @@ func Test_getFilterList(t *testing.T) {
cmd.SetArgs([]string{"-f", "foo=bar,baz=bat"})
cmd.Execute()

filterList, err := getFilterList(cmd)
assert.Len(filterList, 2)
targetList, err := getTargetList(cmd)
assert.Len(targetList, 2)
assert.NoError(err)

cmd.ResetFlags()
Expand Down Expand Up @@ -346,6 +346,30 @@ func Test_validateRunFlags(t *testing.T) {
})
}

func Test_validateSessionFlags(t *testing.T) {
assert := assert.New(t)

cmd := NewTestCmd()
cmdutil.AddFilterFlag(cmd)
cmdutil.AddInstanceFlag(cmd)
cmd.Execute()

instanceList := make([]string, 51)

t.Run("try to use --filter and --instance flags", func(t *testing.T) {
filterList := map[string]string{
"foo": "bar",
}
err := validateSessionFlags(cmd, instanceList, filterList)
assert.Error(err)
})

t.Run("valid flag combination", func(t *testing.T) {
err := validateSessionFlags(cmd, instanceList, nil)
assert.NoError(err)
})
}

func Test_setLogLevel(t *testing.T) {
assert := assert.New(t)

Expand Down
7 changes: 3 additions & 4 deletions cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func runCommand(cmd *cobra.Command, args []string) {
if commandList, err = getCommandList(cmd); err != nil {
log.Fatal(err)
}
if targets, err = getFilterList(cmd); err != nil {
if targets, err = getTargetList(cmd); err != nil {
log.Fatal(err)
}

Expand Down Expand Up @@ -95,11 +95,10 @@ func runCommand(cmd *cobra.Command, args []string) {
MaxErrors: aws.String(maxErrors),
}

// Set up our AWS session for each permutation of profile + region
sessionPool := session.NewPoolSafe(profileList, regionList, log)

wg, output := sync.WaitGroup{}, invocation.ResultSafe{}

// Set up our AWS session for each permutation of profile + region and iterate over them
sessionPool := session.NewPoolSafe(profileList, regionList, log)
for _, sess := range sessionPool.Sessions {
wg.Add(1)
ssmClient := ssm.New(sess.Session)
Expand Down

0 comments on commit cf13c79

Please sign in to comment.