Skip to content

Commit

Permalink
Adds config file support
Browse files Browse the repository at this point in the history
  • Loading branch information
Gowiem committed Jun 20, 2020
1 parent 097f841 commit 32dbf30
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 47 deletions.
129 changes: 82 additions & 47 deletions cmd/root.go
Original file line number Diff line number Diff line change
@@ -1,24 +1,30 @@
package cmd

import (
"errors"
"fmt"
"os"
"path/filepath"
"strings"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/mitchellh/go-homedir"
"gopkg.in/yaml.v2"

"github.com/sirupsen/logrus"

"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/spf13/viper"
)

var cfgFile string

var log = logrus.New()

var fs = afero.NewOsFs()

var newEcsClient func(*RunConfig) ECSClient

var rootCmd *cobra.Command = &cobra.Command{
Expand All @@ -31,6 +37,12 @@ using their existing Task Definitions.`,
Run: func(cmd *cobra.Command, args []string) {
log.Info("Run!")

if err := initConfigFile(); err != nil {
panic(err)
}

enforceRequired()

config := BuildRunConfig()
ecsClient := newEcsClient(config)

Expand All @@ -56,15 +68,16 @@ func Execute(n func(*RunConfig) ECSClient) {
}

func init() {
cobra.OnInitialize(initConfig, initEnvVars, initRequired, initVerbose, initAws)
cobra.OnInitialize(initEnvVars, initVerbose, initAws)

log.SetOutput(os.Stderr)

// Basic Flags
rootCmd.PersistentFlags().BoolP("verbose", "v", false, "verbose output")

// Config File Flags
rootCmd.PersistentFlags().StringVar(&cfgFile, "config-file", "", "config file (default is $PWD/escrun.yml or $HOME/ecsrun.yml)")
// TODO: Add this back at another time
// rootCmd.PersistentFlags().StringVar(&cfgFile, "config-file", "", "config file (default is $PWD/escrun.yml)")
rootCmd.PersistentFlags().String("config", "default", "config entry to read in the config file (default is 'default')")

// AWS Cred / Environment Flags
Expand All @@ -85,31 +98,8 @@ func init() {
rootCmd.PersistentFlags().StringP("subnet", "s", "", "The Subnet ID that the task should be launched in.")
rootCmd.PersistentFlags().StringP("security-group", "g", "", "The Security Group ID that the task should be associated with.")
rootCmd.PersistentFlags().Bool("public", false, "Assigns a public IP to the task if included. (default is false)")
}

// initConfig reads in config file and ENV variables if set.
func initConfig() {
if cfgFile != "" {
// Use config file from the flag.
viper.SetConfigFile(cfgFile)
} else {
// Find home directory.
home, err := homedir.Dir()
if err != nil {
panic(err)
}

// Search config in home and current directory with name "escrun.yml" (without extension).
viper.AddConfigPath(home)
viper.AddConfigPath(".")
viper.SetConfigType("yaml")
viper.SetConfigName("ecsrun.yml")
}

// If a config file is found, read it in.
if err := viper.ReadInConfig(); err == nil {
log.Info("Using config file:", viper.ConfigFileUsed())
}
viper.BindPFlags(rootCmd.PersistentFlags())
}

func initEnvVars() {
Expand All @@ -127,23 +117,6 @@ func initEnvVars() {
viper.AutomaticEnv()
}

func initRequired() {
// NOTE: This is a work around for using required flags with Viper Env Vars
// https://github.com/spf13/viper/issues/397
viper.BindPFlags(rootCmd.PersistentFlags())
rootCmd.Flags().VisitAll(func(f *pflag.Flag) {
if viper.IsSet(f.Name) && viper.GetString(f.Name) != "" {
rootCmd.Flags().Set(f.Name, viper.GetString(f.Name))
}
})

rootCmd.MarkPersistentFlagRequired("cluster")
rootCmd.MarkPersistentFlagRequired("task")
rootCmd.MarkPersistentFlagRequired("cmd")
rootCmd.MarkPersistentFlagRequired("subnet")
rootCmd.MarkPersistentFlagRequired("security-group")
}

func initVerbose() {
verbose, err := rootCmd.PersistentFlags().GetBool("verbose")
if err != nil {
Expand Down Expand Up @@ -203,14 +176,76 @@ func initAwsSession(profile string) (*session.Session, error) {
Credentials: credentials.NewSharedCredentials(credFile, profile),
})
} else {
sesh, err = session.NewSessionWithOptions(session.Options{
sesh = session.Must(session.NewSessionWithOptions(session.Options{
Profile: profile,
SharedConfigState: session.SharedConfigEnable,
Config: aws.Config{
CredentialsChainVerboseErrors: aws.Bool(true),
},
})
}))
}

return sesh, err
}

func initConfigFile() error {
filename, err := findConfigFile()
if err != nil {
return err
}

log.Debug("Using config file: ", filename)

file, err := afero.ReadFile(fs, filename)
if err != nil {
return err
}

config := make(map[string]map[string]interface{})
if err := yaml.Unmarshal(file, &config); err != nil {
return err
}

log.Debug("Full config file contents: ", config)

configEntry := viper.GetString("config")
if err = viper.MergeConfigMap(config[configEntry]); err != nil {
return err
}

return nil
}

func findConfigFile() (string, error) {
supportedExts := []string{"yaml", "yml"}

for _, extension := range supportedExts {
filename := filepath.Join(".", "ecsrun"+"."+extension)
exists, err := afero.Exists(fs, filename)
if err != nil {
log.Fatal("Failed to check if file exists: ", err)
}

if exists {
return filename, nil
}
}
return "", errors.New("config file not found")
}

func enforceRequired() error {
requiredFlags := []string{"cluster", "task", "cmd", "subnet", "security-group"}
unsetFlags := []string{}
for _, flag := range requiredFlags {
if !viper.IsSet(flag) {
unsetFlags = append(unsetFlags, flag)
}
}

if len(unsetFlags) > 0 {
errMsg := fmt.Sprintf("The following are required arguments: %s", strings.Join(unsetFlags, ","))
return errors.New(errMsg)
}

return nil
}
4 changes: 4 additions & 0 deletions ecsrun.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
default:
cluster: "testing"
task: "testing2"
security-group: "sg1"

0 comments on commit 32dbf30

Please sign in to comment.