Skip to content

Commit

Permalink
Adds first working version using RunConfig
Browse files Browse the repository at this point in the history
  • Loading branch information
Gowiem committed May 25, 2020
1 parent 90da0f0 commit c5e9c2b
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 74 deletions.
39 changes: 8 additions & 31 deletions cmd/ecs_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ package cmd

import (
// "fmt"
"os"

"github.com/aws/aws-sdk-go/service/ecs"
"github.com/aws/aws-sdk-go/service/ecs/ecsiface"
)

// ECSClient is the wrapper around the aws-sdk ECS client and its various structs / methods.
type ECSClient interface {
BuildRunTaskInput() (*ecs.RunTaskInput, error)
BuildRunTaskInput() *ecs.RunTaskInput
RunTask(runTaskInput *ecs.RunTaskInput) (*ecs.RunTaskOutput, error)
}

Expand All @@ -32,19 +32,16 @@ func newClient(client ecsiface.ECSAPI, config *RunConfig) ECSClient {
}
}

func (c *ecsClient) BuildRunTaskInput() (*ecs.RunTaskInput, error) {

taskDefition := c.getTaskDefinition()
assignPublicIP := c.getAssignPublicIp()
func (c *ecsClient) BuildRunTaskInput() *ecs.RunTaskInput {

runInput := &ecs.RunTaskInput{
Cluster: &c.config.Cluster,
TaskDefinition: &taskDefition,
TaskDefinition: &c.config.TaskDefinition,
Count: &c.config.Count,
LaunchType: &c.config.LaunchType,
NetworkConfiguration: &ecs.NetworkConfiguration{
AwsvpcConfiguration: &ecs.AwsVpcConfiguration{
AssignPublicIp: &assignPublicIP,
AssignPublicIp: &c.config.AssignPublicIP,
SecurityGroups: []*string{&c.config.SecurityGroupID},
Subnets: []*string{&c.config.SubnetID},
},
Expand All @@ -53,40 +50,20 @@ func (c *ecsClient) BuildRunTaskInput() (*ecs.RunTaskInput, error) {
ContainerOverrides: []*ecs.ContainerOverride{
{
Command: c.config.Command,
Name: &def,
Name: &c.config.ContainerName,
},
},
},
}

return runInput, nil
return runInput
}

func (c *ecsClient) RunTask(runTaskInput *ecs.RunTaskInput) (*ecs.RunTaskOutput, error) {

output, err := client.RunTask(runInput)
output, err := c.client.RunTask(runTaskInput)
if err != nil {
log.Fatal("Received error when invoking RunTask.", err)
log.Fatal("Error: ", err)
os.Exit(1)
}

log.Info("Output: ", output)
return output, err
}

func (c *ecsClient) getTaskDefinition() string {
if c.config.TaskDefinitionRevision != nil {
return c.config.TaskDefinitionName + ":" + c.config.TaskDefinitionRevision
}

return c.config.TaskDefinitionName
}

func (c *ecsClient) getAssignPublicIP() string {
if c.config.AssignPublicIP {
return ecs.AssignPublicIpEnabled
}

return ecs.AssignPublicIpDisabled
}
93 changes: 50 additions & 43 deletions cmd/root.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package cmd

import (
"fmt"
"os"

"github.com/aws/aws-sdk-go/aws"
Expand All @@ -25,52 +24,68 @@ var rootCmd = &cobra.Command{
Short: "Easily run one-off tasks against a ECS Cluster",
Long: `
ecsrun is a CLI tool that allows users to run one-off administrative tasks
using their existing Task Definitions.
TODO: Supply more info here.`,
using their existing Task Definitions.`,

Run: func(cmd *cobra.Command, args []string) {
cluster := viper.GetString("cluster")
def := viper.GetString("def")
runCmd := viper.GetString("cmd")
config := BuildRunConfig(awsSession)

ecsClient := NewEcsClient(config)

input := ecsClient.BuildRunTaskInput()
output, err := ecsClient.RunTask(input)
if err != nil {
log.Fatal(err)
}

log.Info("RunTask output: ", output)
},
}

// Execute adds all child commands to the root command and sets flags appropriately.
// This is called by main.main(). It only needs to happen once to the rootCmd.
func Execute() {
if err := rootCmd.Execute(); err != nil {
fmt.Println(err)
os.Exit(1)
log.Fatal(err)
}
}

func init() {
log.Debug("Root init.")
cobra.OnInitialize(initConfig, initVerbose, initAws, buildRunConfig)
cobra.OnInitialize(initConfig, initVerbose, initAws)

log.SetOutput(os.Stderr)

// Basic Flags
rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.escrun)")
rootCmd.PersistentFlags().BoolP("verbose", "v", false, "verbose output")

// AWS Cred / Environment Flags
rootCmd.PersistentFlags().String("cred", "", "AWS credentials file (default is $HOME/.aws/.credentials)")
rootCmd.PersistentFlags().StringP("profile", "p", "", "AWS profile to target (default is AWS_PROFILE or 'default')")
rootCmd.PersistentFlags().StringP("region", "r", "", `AWS region to target (default is AWS_REGION or pulled from $HOME/.aws/.credentials)`)

rootCmd.PersistentFlags().String("cluster", "", "The ECS Cluster to run the task in.")
rootCmd.PersistentFlags().StringP("def", "d", "", "The ECS Task Definition to use.")
rootCmd.PersistentFlags().StringP("cmd", "c", "", "The ECS Task Definition to use.")

rootCmd.MarkFlagRequired("cluster")
rootCmd.MarkFlagRequired("def")
rootCmd.MarkFlagRequired("cmd")

viper.BindPFlag("profile", rootCmd.PersistentFlags().Lookup("profile"))
viper.BindPFlag("region", rootCmd.PersistentFlags().Lookup("region"))
viper.BindPFlag("cluster", rootCmd.PersistentFlags().Lookup("cluster"))
viper.BindPFlag("def", rootCmd.PersistentFlags().Lookup("def"))
viper.BindPFlag("cmd", rootCmd.PersistentFlags().Lookup("cmd"))
rootCmd.PersistentFlags().String("region", "", `AWS region to target (default is AWS_REGION or pulled from $HOME/.aws/.credentials)`)

// Task Flags
rootCmd.PersistentFlags().StringP("cluster", "c", "", "The ECS Cluster to run the task in.")
rootCmd.PersistentFlags().StringP("task", "t", "", "The name of the ECS Task Definition to use.")
rootCmd.PersistentFlags().StringP("revision", "r", "", "The Task Definition revision to use.")
rootCmd.PersistentFlags().StringP("name", "n", "", "The name of the container in the Task Definition.")
rootCmd.PersistentFlags().StringP("launch-type", "l", "FARGATE", "The launch type to run as. Currently only Fargate is supported.")
rootCmd.PersistentFlags().StringSlice("cmd", []string{}, "The comma separated command override to apply.")
rootCmd.PersistentFlags().Int64("count", 1, "The number of tasks to launch for the given cmd.")

// Network Flags
rootCmd.PersistentFlags().StringP("subnet", "s", "", "The Subnet ID that the task should be launched in.")
rootCmd.PersistentFlags().StringP("security-group", "g", "", "The Security Group ID that the task should be associated with.")
rootCmd.PersistentFlags().Bool("public", false, "Assigns a public IP to the task if included. (default is false)")

// Require specific flags
rootCmd.MarkPersistentFlagRequired("cluster")
rootCmd.MarkPersistentFlagRequired("task")
rootCmd.MarkPersistentFlagRequired("cmd")
rootCmd.MarkPersistentFlagRequired("subnet")
rootCmd.MarkPersistentFlagRequired("security-group")

// Bind em All
viper.BindPFlags(rootCmd.PersistentFlags())
}

// initConfig reads in config file and ENV variables if set.
Expand All @@ -82,8 +97,7 @@ func initConfig() {
// Find home directory.
home, err := homedir.Dir()
if err != nil {
fmt.Println(err)
os.Exit(1)
log.Fatal(err)
}

// Search config in home directory with name ".escrun" (without extension).
Expand All @@ -95,16 +109,14 @@ func initConfig() {

// If a config file is found, read it in.
if err := viper.ReadInConfig(); err == nil {
fmt.Println("Using config file:", viper.ConfigFileUsed())
log.Info("Using config file:", viper.ConfigFileUsed())
}
}

func initVerbose() {
verbose, err := rootCmd.PersistentFlags().GetBool("verbose")
if err != nil {
log.Fatal("Unable to pull verbose flag.")
log.Fatal(err)
os.Exit(1)
log.Fatal("Unable to pull verbose flag.", err)
}

if verbose {
Expand All @@ -115,24 +127,25 @@ func initVerbose() {

func initAws() {
profile := getProfile()
viper.Set("profile", profile)

// Create our AWS session object for AWS API Usage
sesh, err := initAwsSession(profile)
if err != nil {
log.Fatal("Unable to init AWS Session. Check your credentials and profile.")
log.Fatal(err)
os.Exit(1)
log.Fatal("Unable to init AWS Session. Check your credentials and profile.", err)
}

region := viper.GetString("region")
if region == "" {
region = *sesh.Config.Region
}

// Override our Session's region in case it was set.
sesh.Config.WithRegion(region)

viper.Set("profile", profile)
viper.Set("region", region)
// Set our awsSession for later use.
// TODO: What's the proper way to do this... This seems weird.
awsSession = sesh
}

func getProfile() string {
Expand All @@ -151,9 +164,7 @@ func getProfile() string {
func initAwsSession(profile string) (*session.Session, error) {
credFile, err := rootCmd.PersistentFlags().GetString("cred")
if err != nil {
log.Fatal("Not able to get credFile from cmd.")
log.Fatal(err)
os.Exit(1)
log.Fatal("Not able to get credFile from cmd.", err)
}

log.Debug("Cred File: " + credFile)
Expand All @@ -176,7 +187,3 @@ func initAwsSession(profile string) (*session.Session, error) {

return sesh, err
}

func buildRunConfig() (*RunConfig) {

}
87 changes: 87 additions & 0 deletions cmd/run_config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package cmd

import (
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ecs"
"github.com/spf13/viper"
)

// RunConfig is the main config object used to configure the RunTask.
type RunConfig struct {
Command []*string
Cluster string
TaskDefinition string
TaskDefinitionName string
TaskDefinitionRevision string
ContainerName string
LaunchType string
Count int64

SubnetID string
SecurityGroupID string
AssignPublicIPFlag bool
AssignPublicIP string

Session *session.Session
}

// BuildRunConfig constructs the our primary RunConfig object using the given
// AWS session and the CLI args from viper.
func BuildRunConfig(session *session.Session) *RunConfig {

// Convert our cmd slice to a slice of pointers

cmd := getNormalizedCmd()
taskDef := getTaskDefinition()
name := getContainerName()
assignPublicIP := getAssignPublicIP()

return &RunConfig{
Command: cmd,
Cluster: viper.GetString("cluster"),
TaskDefinition: taskDef,
TaskDefinitionName: viper.GetString("task"),
TaskDefinitionRevision: viper.GetString("revision"),
ContainerName: name,
LaunchType: viper.GetString("launch-type"),
Count: viper.GetInt64("count"),
SubnetID: viper.GetString("subnet"),
SecurityGroupID: viper.GetString("security-group"),
AssignPublicIPFlag: viper.GetBool("public"),
AssignPublicIP: assignPublicIP,
Session: session,
}
}

func getNormalizedCmd() []*string {
result := []*string{}
for _, v := range viper.GetStringSlice("cmd") {
result = append(result, &v)
}

return result
}

func getTaskDefinition() string {
if viper.GetString("revision") != "" {
return viper.GetString("task") + ":" + viper.GetString("revision")
}

return viper.GetString("task")
}

func getContainerName() string {
if viper.GetString("name") != "" {
return viper.GetString("name")
}

return viper.GetString("task")
}

func getAssignPublicIP() string {
if viper.GetBool("public") {
return ecs.AssignPublicIpEnabled
}

return ecs.AssignPublicIpDisabled
}

0 comments on commit c5e9c2b

Please sign in to comment.