diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..acb7976 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,20 @@ +on: [push] +name: tests +jobs: + test: + strategy: + matrix: + go-version: [1.13] + platform: [ubuntu-latest] + runs-on: ${{ matrix.platform }} + steps: + - name: Install Go + uses: actions/setup-go@v1 + with: + go-version: ${{ matrix.go-version }} + - name: Checkout code + uses: actions/checkout@v1 + - name: Lint + run: make lint + - name: Tests + run: make tests diff --git a/.gitignore b/.gitignore index 7b28844..d9b46d3 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ .idea examples/*/* !examples/*/*.go +artifacts diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..403924c --- /dev/null +++ b/Makefile @@ -0,0 +1,19 @@ +all: + +lint-local: + @echo "Running linters" + golangci-lint cache clean + golangci-lint run -v ./... + +lint: + @echo "Running linters" + docker run --rm -v $(PWD):/app -w /app golangci/golangci-lint:v1.21.0 golangci-lint run -v ./... + +tests: + @echo "Running tests" + @mkdir -p artifacts + go test -race -count=1 -cover -coverprofile=artifacts/coverage.out -v ./... + +coverage: tests + @echo "Running tests & coverage" + go tool cover -html=artifacts/coverage.out -o artifacts/coverage.html diff --git a/config.go b/config.go index 22d0e30..2419222 100644 --- a/config.go +++ b/config.go @@ -18,8 +18,12 @@ func (c *config) Value() interface{} { if v.Kind() == reflect.Func { t := v.Type() - if t.NumIn() != 0 && t.NumOut() != 1 { - panic("Function type must have no input parameters and a single return value") + if t.NumIn() != 0 { + panic("Function type must have no input parameters") + } + + if t.NumOut() != 1 { + panic("Function type must have a single return value") } if t.Out(0).Kind().String() != "interface" { diff --git a/config_test.go b/config_test.go new file mode 100644 index 0000000..e42f5ed --- /dev/null +++ b/config_test.go @@ -0,0 +1,132 @@ +package exec + +import ( + "github.com/stretchr/testify/require" + "reflect" + "testing" + "time" +) + +func TestConfig_Value(t *testing.T) { + type testCase struct { + test string + cfg *config + expectedPanic interface{} + } + + testCases := []testCase{ + { + test: "valid simple value", + cfg: &config{ + value: "", + }, + }, + { + test: "valid func value", + cfg: &config{ + value: func() interface{} { + return "value" + }, + }, + }, + { + test: "panics on invalid func with input param", + cfg: &config{ + value: func(s string) interface{} { + return "value" + }, + }, + expectedPanic: "Function type must have no input parameters", + }, + { + test: "panics on invalid func with more than one return param", + cfg: &config{ + value: func() (interface{}, string) { + return "value", "one" + }, + }, + expectedPanic: "Function type must have a single return value", + }, + { + test: "panics on invalid func with no input param", + cfg: &config{ + value: func() string { + return "value" + }, + }, + expectedPanic: "Function return value must be an interface{}", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.test, func(t *testing.T) { + if testCase.expectedPanic != nil { + require.PanicsWithValue(t, testCase.expectedPanic, func() { + testCase.cfg.Value() + }) + } else { + require.Equal(t, testCase.cfg.value, testCase.cfg.Value()) + } + }) + } +} + +func TestConfig_String(t *testing.T) { + cfg := &config{ + Name: "name", + value: "string", + } + + require.Equal(t, cfg.value, cfg.String()) + require.IsType(t, reflect.TypeOf(cfg.value), reflect.TypeOf(cfg.String())) +} + +func TestConfig_Int(t *testing.T) { + cfg := &config{ + Name: "name", + value: 1, + } + + require.Equal(t, cfg.value, cfg.Int()) + require.IsType(t, reflect.TypeOf(cfg.value), reflect.TypeOf(cfg.Int())) +} + +func TestConfig_Int64(t *testing.T) { + cfg := &config{ + Name: "name", + value: int64(1), + } + + require.Equal(t, cfg.value, cfg.Int64()) + require.IsType(t, reflect.TypeOf(cfg.value), reflect.TypeOf(cfg.Int64())) +} + +func TestConfig_Bool(t *testing.T) { + cfg := &config{ + Name: "name", + value: true, + } + + require.Equal(t, cfg.value, cfg.Bool()) + require.IsType(t, reflect.TypeOf(cfg.value), reflect.TypeOf(cfg.Bool())) +} + +func TestConfig_Slice(t *testing.T) { + cfg := &config{ + Name: "name", + value: []string{"a", "b"}, + } + + require.Equal(t, cfg.value, cfg.Slice()) + require.IsType(t, reflect.TypeOf(cfg.value), reflect.TypeOf(cfg.Slice())) +} + +func TestConfig_Time(t *testing.T) { + cfg := &config{ + Name: "name", + value: time.Now(), + } + + require.Equal(t, cfg.value, cfg.Time()) + require.IsType(t, reflect.TypeOf(cfg.value), reflect.TypeOf(cfg.Time())) +} diff --git a/examples/simple/main.go b/examples/simple/main.go index 9bc6186..5a8f3c6 100644 --- a/examples/simple/main.go +++ b/examples/simple/main.go @@ -3,7 +3,7 @@ package main import ( "fmt" "github.com/fatih/color" - "github.com/go-exec/exec" + e "github.com/go-exec/exec" "time" ) @@ -11,12 +11,15 @@ import ( Example with general setup of tasks */ func main() { + exec := e.Instance + defer exec.Run() + exec.Task("onStart", func() { exec.Set("startTime", time.Now()) }).Private() exec.Task("onEnd", func() { - exec.Println(fmt.Sprintf("Finished in %s!`", time.Now().Sub(exec.Get("startTime").Time()).String())) + exec.Println(fmt.Sprintf("Finished in %s!", time.Since(exec.Get("startTime").Time()).String())) }).Private() type F struct { @@ -25,7 +28,7 @@ func main() { stage := exec.NewArgument("stage", "Provide the running stage") stage.Default = "qa" - stage.Type = exec.String + stage.Type = e.String exec.AddArgument(stage) @@ -78,12 +81,18 @@ func main() { Task("upload", func() { exec.Remote("ls -la /") exec.Upload("test.txt", "~/test.txt") + }). + OnServers(func() []string { + return []string{"prod1"} }) exec. Task("download", func() { exec.Remote("ls -la /") exec.Download("~/test.txt", "test.txt") + }). + OnServers(func() []string { + return []string{"prod1"} }) exec. @@ -266,6 +275,4 @@ func main() { exec.After("local", "onservers:a") exec.After("local", "get3") exec.After("onservers:a", "local") - - exec.Init() } diff --git a/examples/symfony/main.go b/examples/symfony/main.go index a34d138..3c95356 100644 --- a/examples/symfony/main.go +++ b/examples/symfony/main.go @@ -9,6 +9,9 @@ import ( Example of deploying a Symfony app using the deploy recipes */ func main() { + exec := exec.Instance + defer exec.Run() + exec.Set("repository", "git@github.com:namespace/app.git") exec.Set("shared_files", []string{}) exec.Set("shared_dirs", []string{"var/logs", "vendor", "web/uploads", "web/media", "node_modules"}) @@ -38,6 +41,4 @@ func main() { exec.OnServers(func() []string { return []string{exec.GetArgument("stage").String()} }) - - exec.Init() } diff --git a/exec.go b/exec.go index f367a35..ab6554e 100644 --- a/exec.go +++ b/exec.go @@ -10,63 +10,82 @@ import ( "strings" ) -var ( +type Exec struct { // Configs contains all exec context vars used by Get and Set - Configs = make(map[string]*config) + Configs map[string]*config // Tasks contains all exec tasks - Tasks = make(map[string]*task) + Tasks map[string]*task // Servers contains all exec servers - Servers = make(map[string]*server) + Servers map[string]*server // TaskGroups contains all exec task groups - TaskGroups = make(map[string]*taskGroup) + TaskGroups map[string]*taskGroup // Arguments contains all exec arguments - Arguments = make(map[string]*Argument) + Arguments map[string]*Argument // Options contains all exec options - Options = make(map[string]*Option) + Options map[string]*Option // ServerContext is the current active server ServerContext *server // TaskContext is the current executed task TaskContext *task - before = make(map[string][]string) - after = make(map[string][]string) - serverContextF = func() []string { return nil } //must return one server name + before map[string][]string + after map[string][]string + serverContextF func() []string //must return one server name argumentSequence int -) +} + +// New returns a new *Exec instance +func New() *Exec { + return &Exec{ + Configs: make(map[string]*config), + Tasks: make(map[string]*task), + Servers: make(map[string]*server), + TaskGroups: make(map[string]*taskGroup), + Arguments: make(map[string]*Argument), + Options: make(map[string]*Option), + before: make(map[string][]string), + after: make(map[string][]string), + serverContextF: func() []string { return nil }, + } +} + +// Instance is the default empty exported instance of *Exec +// used to be able to create external recipes easily +var Instance = New() -// Init initializes the exec and executes the current command +// Run initializes the exec and executes the current command // should be added to the end of all exec declarations -func Init() { +func (e *Exec) Run() { subtasks := make(map[string]*task) - for name, task := range Tasks { - task.Arguments = mergeArguments(task.removeArguments, Arguments, task.Arguments) - task.Options = mergeOptions(task.removeOptions, Options, task.Options) + for name, task := range e.Tasks { + task.Arguments = mergeArguments(task.removeArguments, e.Arguments, task.Arguments) + task.Options = mergeOptions(task.removeOptions, e.Options, task.Options) if !task.private { subtasks[name] = task } } - for name := range TaskGroups { - TaskGroups[name].task.Arguments = mergeArguments(TaskGroups[name].task.removeArguments, Arguments, TaskGroups[name].task.Arguments) - TaskGroups[name].task.Options = mergeOptions(TaskGroups[name].task.removeOptions, Options, TaskGroups[name].task.Options) - Tasks[name] = TaskGroups[name].task - subtasks[name] = TaskGroups[name].task + for name := range e.TaskGroups { + e.TaskGroups[name].task.Arguments = mergeArguments(e.TaskGroups[name].task.removeArguments, e.Arguments, e.TaskGroups[name].task.Arguments) + e.TaskGroups[name].task.Options = mergeOptions(e.TaskGroups[name].task.removeOptions, e.Options, e.TaskGroups[name].task.Options) + e.Tasks[name] = e.TaskGroups[name].task + subtasks[name] = e.TaskGroups[name].task } - for _, task := range Tasks { - if before[task.Name] != nil { - for _, bt := range before[task.Name] { - if Tasks[bt] != nil { - task.before = append(task.before, Tasks[bt]) + for _, task := range e.Tasks { + if e.before[task.Name] != nil { + for _, bt := range e.before[task.Name] { + if e.Tasks[bt] != nil { + task.before = append(task.before, e.Tasks[bt]) } } } - if after[task.Name] != nil { - for _, at := range after[task.Name] { - if Tasks[at] != nil { - task.after = append(task.after, Tasks[at]) + if e.after[task.Name] != nil { + for _, at := range e.after[task.Name] { + if e.Tasks[at] != nil { + task.after = append(task.after, e.Tasks[at]) } } } @@ -76,20 +95,20 @@ func Init() { subtasks: subtasks, } - rootTask.Arguments = Arguments - rootTask.Options = mergeOptions(map[string]string{}, Options, rootTask.Options) + rootTask.Arguments = e.Arguments + rootTask.Options = mergeOptions(map[string]string{}, e.Options, rootTask.Options) if err := run(&rootTask); err != nil { - fmt.Fprintln(os.Stderr, err) + _, _ = fmt.Fprintln(os.Stderr, err) } else { - for _, s := range Servers { - s.sshClient.Close() + for _, s := range e.Servers { + _ = s.sshClient.Close() } } } // NewArgument returns a new Argument -func NewArgument(name string, description string) *Argument { +func (e *Exec) NewArgument(name string, description string) *Argument { var arg = &Argument{ Name: name, Description: description, @@ -98,17 +117,17 @@ func NewArgument(name string, description string) *Argument { } // AddArgument adds an Argument to exec -func AddArgument(argument *Argument) { - if _, ok := Arguments[argument.Name]; !ok { - argument.sequence = argumentSequence - argumentSequence++ - Arguments[argument.Name] = argument +func (e *Exec) AddArgument(argument *Argument) { + if _, ok := e.Arguments[argument.Name]; !ok { + argument.sequence = e.argumentSequence + e.argumentSequence++ + e.Arguments[argument.Name] = argument } } // GetArgument return an Argument pointer -func GetArgument(name string) *Argument { - if arg, ok := Arguments[name]; ok { +func (e *Exec) GetArgument(name string) *Argument { + if arg, ok := e.Arguments[name]; ok { return arg } @@ -116,7 +135,7 @@ func GetArgument(name string) *Argument { } // NewOption returns a new Option -func NewOption(name string, description string) *Option { +func (e *Exec) NewOption(name string, description string) *Option { var opt = &Option{ Name: name, Description: description, @@ -125,15 +144,15 @@ func NewOption(name string, description string) *Option { } // AddOption adds an Option to exec -func AddOption(option *Option) { - if _, ok := Options[option.Name]; !ok { - Options[option.Name] = option +func (e *Exec) AddOption(option *Option) { + if _, ok := e.Options[option.Name]; !ok { + e.Options[option.Name] = option } } // GetOption return an Option pointer -func GetOption(name string) *Option { - if opt, ok := Options[name]; ok { +func (e *Exec) GetOption(name string) *Option { + if opt, ok := e.Options[name]; ok { return opt } @@ -141,72 +160,73 @@ func GetOption(name string) *Option { } // Set sets a exec Config -func Set(name string, value interface{}) { - Configs[name] = &config{Name: name, value: value} +func (e *Exec) Set(name string, value interface{}) { + e.Configs[name] = &config{Name: name, value: value} } // Get gets a Config value either set in a Server or directly in exec -func Get(name string) *config { - if ServerContext != nil { - if c, ok := ServerContext.Configs[name]; ok { +func (e *Exec) Get(name string) *config { + if e.ServerContext != nil { + if c, ok := e.ServerContext.Configs[name]; ok { return c } } - if c, ok := Configs[name]; ok { + if c, ok := e.Configs[name]; ok { return c } return nil } // Has checks if a Config is available -func Has(name string) bool { - if ServerContext != nil { - if _, ok := ServerContext.Configs[name]; ok { +func (e *Exec) Has(name string) bool { + if e.ServerContext != nil { + if _, ok := e.ServerContext.Configs[name]; ok { return true } } - _, ok := Configs[name] + _, ok := e.Configs[name] return ok } // Server adds a new Server to exec // dsn should be user@host:port -func Server(name string, dsn string) *server { - Servers[name] = &server{ +func (e *Exec) Server(name string, dsn string) *server { + e.Servers[name] = &server{ Name: name, Dsn: dsn, Configs: make(map[string]*config), sshClient: &sshClient{}, } - return Servers[name] + return e.Servers[name] } // Task inherits the exec Arguments and can override and/or have new Options // it accepts a name and a func; the func content is executed on each command execution -func Task(name string, f func()) *task { - Tasks[name] = &task{ +func (e *Exec) Task(name string, f func()) *task { + e.Tasks[name] = &task{ Name: name, Arguments: make(map[string]*Argument), Options: make(map[string]*Option), + exec: e, removeArguments: make(map[string]string), removeOptions: make(map[string]string), serverContextF: func() []string { return nil }, } - Tasks[name].run = func() { + e.Tasks[name].run = func() { // set task context - TaskContext = Tasks[name] + e.TaskContext = e.Tasks[name] - run, onServers := shouldIRun() + run, onServers := e.shouldIRun() //skip tasks's server checking if requested if run && len(onServers) > 0 { - for _, server := range Servers { + for _, server := range e.Servers { for _, onServer := range onServers { - if (server.Name == onServer || server.HasRole(onServer)) && Servers[onServer] != nil { + if (server.Name == onServer || server.HasRole(onServer)) && e.Servers[onServer] != nil { // set server context - ServerContext = server + e.ServerContext = server color.White("➤ Executing task %s on server %s", color.YellowString(name), color.GreenString(fmt.Sprintf("[%s]", server.Name))) @@ -214,7 +234,7 @@ func Task(name string, f func()) *task { f() //reset server context - ServerContext = nil + e.ServerContext = nil } } } @@ -225,57 +245,58 @@ func Task(name string, f func()) *task { //execute task's func f() } else { - taskNotAllowedToRunPrint(onServers, name) + e.taskNotAllowedToRunPrint(onServers, name) } //reset task context - TaskContext = nil + e.TaskContext = nil } - return Tasks[name] + return e.Tasks[name] } // TaskGroup inherits the exec Arguments and can override and/or have new Options // and it will run all associated tasks -func TaskGroup(name string, tasks ...string) *taskGroup { - TaskGroups[name] = &taskGroup{ +func (e *Exec) TaskGroup(name string, tasks ...string) *taskGroup { + e.TaskGroups[name] = &taskGroup{ Name: name, task: &task{ Name: name, removeArguments: make(map[string]string), removeOptions: make(map[string]string), + exec: e, run: func() { color.White("➤ Executing task group %s", color.YellowString(name)) for _, task := range tasks { - if Tasks[task] == nil { + if e.Tasks[task] == nil { continue } - if Tasks[task].once && Tasks[task].executedOnce { + if e.Tasks[task].once && e.Tasks[task].executedOnce { continue } //set task context - TaskContext = Tasks[task] + e.TaskContext = e.Tasks[task] - Tasks[task].run() + e.Tasks[task].run() - if Tasks[task].once && !Tasks[task].executedOnce { - Tasks[task].executedOnce = true + if e.Tasks[task].once && !e.Tasks[task].executedOnce { + e.Tasks[task].executedOnce = true } //reset task context - TaskContext = nil + e.TaskContext = nil } }, }, } - TaskGroups[name].tasks = append(TaskGroups[name].tasks, tasks...) - return TaskGroups[name] + e.TaskGroups[name].tasks = append(e.TaskGroups[name].tasks, tasks...) + return e.TaskGroups[name] } // Local runs a local command and displays/returns the output for further usage, for example in a Task func -func Local(command string, args ...interface{}) (o output) { - command = Parse(fmt.Sprintf(command, args...)) +func (e *Exec) Local(command string, args ...interface{}) (o Output) { + command = e.Parse(fmt.Sprintf(command, args...)) color.Green("[%s] %s %s", "local", ">", color.WhiteString("`%s`", command)) @@ -326,7 +347,7 @@ func Local(command string, args ...interface{}) (o output) { } } - o.text = strings.TrimSpace(string(output)) + o.text = strings.TrimSpace(output) if len(o.text) == 0 { color.Red("[%s] %s\n", "local", "<") @@ -343,19 +364,87 @@ func Local(command string, args ...interface{}) (o output) { } // Println parses a text template, if founds a {{ var }}, it automatically runs the Get(var) on it -func Println(text string) { - fmt.Println(Parse(text)) +func (e *Exec) Println(text string) { + fmt.Println(e.Parse(text)) } // OnServers sets the server context dynamically -func OnServers(f func() []string) { - serverContextF = f +func (e *Exec) OnServers(f func() []string) { + e.serverContextF = f +} + +// Remote runs a command with args, in the ServerContext +func (e *Exec) Remote(command string, args ...interface{}) (o Output) { + run, onServers := e.shouldIRun() + + if !run { + e.commandNotAllowedToRunPrint(onServers, fmt.Sprintf(command, args...)) + return o + } + + if e.ServerContext != nil { + return e.remoteRun(fmt.Sprintf(command, args...), e.ServerContext) + } + + return o +} + +// Upload uploads a file or directory from local to remote, using native scp binary +func (e *Exec) Upload(local, remote string) { + run, onServers := e.shouldIRun() + + if !run { + e.commandNotAllowedToRunPrint(onServers, fmt.Sprintf("scp (local)%s > (remote)%s", local, remote)) + } + + var args = []string{"scp", "-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -r"} + if e.ServerContext.key != nil { + args = append(args, "-i "+*e.ServerContext.key) + } + args = append(args, local, e.ServerContext.Dsn+":"+remote) + + e.Local(strings.Join(args, " ")) +} + +// Download downloads a file or directory from remote to local, using native scp binary +func (e *Exec) Download(remote, local string) { + run, onServers := e.shouldIRun() + + if !run { + e.commandNotAllowedToRunPrint(onServers, fmt.Sprintf("scp (remote)%s > (local)%s", local, remote)) + } + + var args = []string{"scp", "-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -r"} + if e.ServerContext.key != nil { + args = append(args, "-i "+*e.ServerContext.key) + } + args = append(args, e.ServerContext.Dsn+":"+remote, local) + + e.Local(strings.Join(args, " ")) +} + +// Before sets tasks to run before task +func (e *Exec) Before(task string, tasksBefore ...string) { + for _, tb := range tasksBefore { + if !contains(e.before[task], tb) { + e.before[task] = append(e.before[task], tb) + } + } +} + +// After sets tasks to run after task +func (e *Exec) After(task string, tasksAfter ...string) { + for _, ta := range tasksAfter { + if !contains(e.after[task], ta) { + e.after[task] = append(e.after[task], ta) + } + } } -// RemoteRun executes a command on a specific server -func RemoteRun(command string, server *server) (o output) { - ServerContext = server - command = Parse(command) +// remoteRun executes a command on a specific server +func (e *Exec) remoteRun(command string, server *server) (o Output) { + e.ServerContext = server + command = e.Parse(command) color.Green("[%s] %s %s", server.Name, ">", color.WhiteString("`%s`", command)) @@ -416,104 +505,36 @@ func RemoteRun(command string, server *server) (o output) { return o } -// Remote runs a command with args, in the ServerContext -func Remote(command string, args ...interface{}) (o output) { - run, onServers := shouldIRun() - - if !run { - commandNotAllowedToRunPrint(onServers, fmt.Sprintf(command, args...)) - return o - } - - if ServerContext != nil { - return RemoteRun(fmt.Sprintf(command, args...), ServerContext) - } - - return o -} - -// Upload uploads a file or directory from local to remote, using native scp binary -func Upload(local, remote string) { - run, onServers := shouldIRun() - - if !run { - commandNotAllowedToRunPrint(onServers, fmt.Sprintf("scp (local)%s > (remote)%s", local, remote)) - } - - var args = []string{"scp", "-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -r"} - if ServerContext.key != nil { - args = append(args, "-i "+*ServerContext.key) - } - args = append(args, local, ServerContext.Dsn+":"+remote) - - Local(strings.Join(args, " ")) -} - -// Download downloads a file or directory from remote to local, using native scp binary -func Download(remote, local string) { - run, onServers := shouldIRun() - - if !run { - commandNotAllowedToRunPrint(onServers, fmt.Sprintf("scp (remote)%s > (local)%s", local, remote)) - } - - var args = []string{"scp", "-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -r"} - if ServerContext.key != nil { - args = append(args, "-i "+*ServerContext.key) - } - args = append(args, ServerContext.Dsn+":"+remote, local) - - Local(strings.Join(args, " ")) -} - -// Before sets tasks to run before task -func Before(task string, tasksBefore ...string) { - for _, tb := range tasksBefore { - if !contains(before[task], tb) { - before[task] = append(before[task], tb) - } - } -} - -// After sets tasks to run after task -func After(task string, tasksAfter ...string) { - for _, ta := range tasksAfter { - if !contains(after[task], ta) { - after[task] = append(after[task], ta) - } - } -} - -func shouldIRun() (run bool, onServers []string) { +func (e *Exec) shouldIRun() (run bool, onServers []string) { run = true //default values if serverContextF is set - if s := serverContextF(); len(s) > 0 { + if s := e.serverContextF(); len(s) > 0 { onServers = s } //inside a task - if TaskContext != nil { + if e.TaskContext != nil { //task has a serverContextF - if s := TaskContext.serverContextF(); len(s) > 0 { + if s := e.TaskContext.serverContextF(); len(s) > 0 { onServers = s } //task needs to run only on some servers - if len(TaskContext.onlyOnServers) > 0 { + if len(e.TaskContext.onlyOnServers) > 0 { run = false for _, oS := range onServers { - for _, oOS := range TaskContext.onlyOnServers { + for _, oOS := range e.TaskContext.onlyOnServers { //task on server matches only on servers if oS == oOS { run = true } } } - onServers = TaskContext.onlyOnServers + onServers = e.TaskContext.onlyOnServers } - if TaskContext.once && TaskContext.executedOnce { + if e.TaskContext.once && e.TaskContext.executedOnce { run = false } } @@ -521,24 +542,24 @@ func shouldIRun() (run bool, onServers []string) { return run, onServers } -func commandNotAllowedToRunPrint(onServers []string, command string) { +func (e *Exec) commandNotAllowedToRunPrint(onServers []string, command string) { fmt.Printf("%s%s%s\n", color.CyanString("[local] > Command `"), color.WhiteString(command), color.CyanString("` can run only on %s", onServers)) } -func taskNotAllowedToRunPrint(onServers []string, task string) { +func (e *Exec) taskNotAllowedToRunPrint(onServers []string, task string) { fmt.Printf("%s%s%s\n", color.CyanString("[local] > Task `"), color.WhiteString(task), color.CyanString("` can run only on %s", onServers)) } // onStart task setup -func onStart() { - if task, ok := Tasks["onStart"]; ok { +func (e *Exec) onStart() { + if task, ok := e.Tasks["onStart"]; ok { task.run() } } // onEnd task setup -func onEnd() { - if task, ok := Tasks["onEnd"]; ok { +func (e *Exec) onEnd() { + if task, ok := e.Tasks["onEnd"]; ok { task.run() } } diff --git a/exec_test.go b/exec_test.go new file mode 100644 index 0000000..1d5d0e5 --- /dev/null +++ b/exec_test.go @@ -0,0 +1,455 @@ +package exec + +import ( + "github.com/go-exec/exec/ssh_mock" + "github.com/stretchr/testify/require" + "testing" +) + +func TestNew(t *testing.T) { + e := New() + + require.IsType(t, &Exec{}, e) +} + +func TestExec_NewArgument(t *testing.T) { + e := New() + + arg := &Argument{ + Name: "name", + Type: 0, + Default: nil, + Multiple: false, + Description: "description", + Value: nil, + } + + require.Equal(t, e.NewArgument(arg.Name, arg.Description), arg) +} + +func TestExec_AddArgument(t *testing.T) { + e := New() + + arg := &Argument{ + Name: "test", + Type: 0, + Default: nil, + Multiple: false, + Description: "", + Value: nil, + } + e.AddArgument(arg) + + require.Equal(t, arg, e.Arguments[arg.Name]) +} + +func TestExec_GetArgument(t *testing.T) { + type testCase struct { + test string + name string + arg *Argument + } + + testCases := []testCase{ + { + test: "valid argument", + name: "valid", + arg: &Argument{ + Name: "valid", + }, + }, + { + test: "invalid argument", + name: "invalid", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.test, func(t *testing.T) { + e := New() + + if testCase.arg != nil { + e.AddArgument(testCase.arg) + } + + require.Equal(t, e.GetArgument(testCase.name), testCase.arg) + }) + } +} + +func TestExec_NewOption(t *testing.T) { + e := New() + + opt := &Option{ + Name: "name", + Type: 0, + Default: nil, + Description: "description", + Value: nil, + } + require.Equal(t, e.NewOption(opt.Name, opt.Description), opt) +} + +func TestExec_AddOption(t *testing.T) { + e := New() + + opt := &Option{ + Name: "name", + Type: 0, + Default: nil, + Description: "description", + Value: nil, + } + e.AddOption(opt) + + require.Equal(t, opt, e.Options[opt.Name]) +} + +func TestExec_GetOption(t *testing.T) { + type testCase struct { + test string + name string + opt *Option + } + + testCases := []testCase{ + { + test: "valid option", + name: "valid", + opt: &Option{ + Name: "valid", + }, + }, + { + test: "invalid option", + name: "invalid", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.test, func(t *testing.T) { + e := New() + + if testCase.opt != nil { + e.AddOption(testCase.opt) + } + + require.Equal(t, e.GetOption(testCase.name), testCase.opt) + }) + } +} + +func TestExec_Set(t *testing.T) { + e := New() + + cfg := &config{ + Name: "cfg", + value: "val", + } + e.Set(cfg.Name, cfg.value) + + require.Equal(t, cfg, e.Configs[cfg.Name]) +} + +func TestExec_Get(t *testing.T) { + type testCase struct { + test string + name string + cfg *config + serverCtx *server + } + + testCases := []testCase{ + { + test: "valid cfg", + name: "valid", + cfg: &config{ + Name: "valid", + }, + }, + { + test: "invalid cfg", + name: "invalid", + }, + { + test: "valid cfg in server ctx", + name: "valid", + cfg: &config{ + Name: "valid", + }, + serverCtx: &server{}, + }, + { + test: "invalid cfg in server ctx", + name: "invalid", + serverCtx: &server{}, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.test, func(t *testing.T) { + e := New() + + if testCase.cfg != nil { + e.Set(testCase.cfg.Name, testCase.cfg.value) + } + + if testCase.serverCtx != nil { + e.ServerContext = testCase.serverCtx + } + + require.Equal(t, e.Get(testCase.name), testCase.cfg) + }) + } +} + +func TestExec_Has(t *testing.T) { + type testCase struct { + test string + name string + cfg *config + serverCtx *server + expectedResult bool + } + + testCases := []testCase{ + { + test: "valid cfg", + name: "valid", + cfg: &config{ + Name: "valid", + }, + expectedResult: true, + }, + { + test: "invalid cfg", + name: "invalid", + expectedResult: false, + }, + { + test: "valid cfg in server ctx", + name: "valid", + cfg: &config{ + Name: "valid", + }, + serverCtx: &server{}, + expectedResult: true, + }, + { + test: "invalid cfg in server ctx", + name: "invalid", + serverCtx: &server{}, + expectedResult: false, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.test, func(t *testing.T) { + e := New() + + if testCase.cfg != nil { + e.Set(testCase.cfg.Name, testCase.cfg.value) + } + + if testCase.serverCtx != nil { + e.ServerContext = testCase.serverCtx + } + + require.Equal(t, e.Has(testCase.name), testCase.expectedResult) + }) + } +} + +func TestExec_Server(t *testing.T) { + e := New() + + cfg := &server{ + Name: "server", + Dsn: "user@host:port", + Configs: make(map[string]*config), + sshClient: &sshClient{}, + } + e.Server(cfg.Name, cfg.Dsn) + + require.Equal(t, cfg, e.Servers[cfg.Name]) +} + +func TestExec_Task(t *testing.T) { + e := New() + + task := &task{ + Name: "task", + Arguments: make(map[string]*Argument), + Options: make(map[string]*Option), + exec: e, + removeArguments: make(map[string]string), + removeOptions: make(map[string]string), + } + e.Task(task.Name, func() {}) + + require.Contains(t, e.Tasks, task.Name) + require.Equal(t, task.exec, e.Tasks[task.Name].exec) +} + +func TestExec_TaskGroup(t *testing.T) { + e := New() + + taskGroup := &taskGroup{ + Name: "taskGroup", + } + e.TaskGroup(taskGroup.Name) + + require.Contains(t, e.TaskGroups, taskGroup.Name) + require.Equal(t, e.TaskGroups[taskGroup.Name].task.exec, e) +} + +func TestExec_Before(t *testing.T) { + type testCase struct { + test string + task *task + before []string + unique int + } + + testCases := []testCase{ + { + test: "valid", + task: &task{ + Name: "task", + }, + before: []string{"before 1"}, + unique: 1, + }, + { + test: "valid with unique before items", + task: &task{ + Name: "task", + }, + before: []string{"before 1", "before 2", "before 1"}, + unique: 2, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.test, func(t *testing.T) { + e := New() + + e.Before(testCase.task.Name, testCase.before...) + + require.Contains(t, e.before, testCase.task.Name) + require.Equal(t, len(e.before[testCase.task.Name]), testCase.unique) + }) + } +} + +func TestExec_After(t *testing.T) { + type testCase struct { + test string + task *task + after []string + unique int + } + + testCases := []testCase{ + { + test: "valid", + task: &task{ + Name: "task", + }, + after: []string{"after 1"}, + unique: 1, + }, + { + test: "valid with unique after items", + task: &task{ + Name: "task", + }, + after: []string{"after 1", "after 2", "after 1"}, + unique: 2, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.test, func(t *testing.T) { + e := New() + + e.Before(testCase.task.Name, testCase.after...) + + require.Contains(t, e.before, testCase.task.Name) + require.Equal(t, len(e.before[testCase.task.Name]), testCase.unique) + }) + } +} + +func TestExec_Remote(t *testing.T) { + type args struct { + command string + args []interface{} + } + testCases := []struct { + name string + args args + wantO Output + }{ + { + name: "echo remote test", + args: args{ + command: `echo hello`, + }, + wantO: Output{ + text: "hello", + }, + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + server := ssh_mock.NewServer(t) + defer server.Shutdown() + conn := server.Dial(ssh_mock.ClientConfig()) + defer conn.Close() + + e := New() + + s := e.Server("mock", "") + + s.sshClient.WithConnection(conn) + + e.ServerContext = s + + gotO := e.Remote(testCase.args.command, testCase.args.args...) + + require.Equal(t, testCase.wantO, gotO, "Remote() = %v, want %v", gotO, testCase.wantO) + }) + } +} + +func TestExec_Local(t *testing.T) { + type args struct { + command string + args []interface{} + } + testCases := []struct { + name string + args args + wantO Output + }{ + { + name: "echo local test", + args: args{ + command: `echo hello`, + }, + wantO: Output{ + text: "hello", + }, + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + e := New() + + gotO := e.Local(testCase.args.command, testCase.args.args...) + + require.Equal(t, testCase.wantO, gotO, "Remote() = %v, want %v", gotO, testCase.wantO) + }) + } +} \ No newline at end of file diff --git a/go.mod b/go.mod index 799d4c7..ac55d3e 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/mattn/go-isatty v0.0.4 // indirect github.com/pkg/errors v0.8.1 github.com/satori/go.uuid v1.2.0 + github.com/stretchr/testify v1.4.0 golang.org/x/crypto v0.0.0-20170118185426-b8a2a83acfe6 golang.org/x/sys v0.0.0-20161214190518-d75a52659825 // indirect gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect diff --git a/go.sum b/go.sum index cff569e..8d78f2a 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fatih/color v1.7.0 h1:DkWD4oS2D8LGGgTQ6IvwJJXSL5Vp2ffcQg58nFV38Ys= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= @@ -11,11 +13,20 @@ github.com/mattn/go-isatty v0.0.4 h1:bnP0vzxcAdeI1zdubAl5PjU6zsERjGZb7raWodagDYs github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= +github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= golang.org/x/crypto v0.0.0-20170118185426-b8a2a83acfe6 h1:cwnjxMgUhW6Oz2++KLc+loQIC0/qUZL1PHXWLuiyCmc= golang.org/x/crypto v0.0.0-20170118185426-b8a2a83acfe6/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/sys v0.0.0-20161214190518-d75a52659825 h1:4d9VvrP9mESHxCpAwE1G5e1D8Ybj9v7pX19HkGQV0lk= golang.org/x/sys v0.0.0-20161214190518-d75a52659825/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/options.go b/options.go index fe3257f..f9eabc0 100644 --- a/options.go +++ b/options.go @@ -73,7 +73,7 @@ func (arg Argument) Explain() string { // String casts a value to a string and panics on failure. func (arg Argument) String() string { - return arg.Value.(string) + return *arg.Value.(*string) } // Bool casts a value to a bool and panics on failure. diff --git a/options_test.go b/options_test.go new file mode 100644 index 0000000..c04c41c --- /dev/null +++ b/options_test.go @@ -0,0 +1,363 @@ +package exec + +import ( + "github.com/stretchr/testify/require" + "testing" +) + +func TestOption_Explain(t *testing.T) { + type testCase struct { + test string + opt *Option + expectedResult string + } + + testCases := []testCase{ + { + test: "single char option", + opt: &Option{ + Name: "r", + }, + expectedResult: "-r", + }, + { + test: "multi char option", + opt: &Option{ + Name: "run", + }, + expectedResult: "--run", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.test, func(t *testing.T) { + require.Equal(t, testCase.expectedResult, testCase.opt.Explain()) + }) + } +} + +func TestOption_String(t *testing.T) { + type testCase struct { + test string + opt *Option + expectedPanic bool + } + + testCases := []testCase{ + { + test: "valid with value string type pointer", + opt: &Option{ + Name: "option", + Type: String, + Value: new(string), + }, + }, + { + test: "invalid value string type with no pointer", + opt: &Option{ + Name: "option", + Type: String, + Value: "string", + }, + expectedPanic: true, + }, + { + test: "invalid value int type with no pointer", + opt: &Option{ + Name: "option", + Type: String, + Value: 0, + }, + expectedPanic: true, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.test, func(t *testing.T) { + if testCase.expectedPanic { + require.Panics(t, func() { + _ = testCase.opt.String() + }) + } else { + require.Equal(t, *testCase.opt.Value.(*string), testCase.opt.String()) + } + }) + } +} + +func TestOption_Int(t *testing.T) { + type testCase struct { + test string + opt *Option + expectedPanic bool + } + + testCases := []testCase{ + { + test: "valid with value int type pointer", + opt: &Option{ + Name: "option", + Type: Int, + Value: new(int), + }, + }, + { + test: "invalid value int type with no pointer", + opt: &Option{ + Name: "option", + Type: Int, + Value: 0, + }, + expectedPanic: true, + }, + { + test: "invalid value string type with no pointer", + opt: &Option{ + Name: "option", + Type: Int, + Value: "string", + }, + expectedPanic: true, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.test, func(t *testing.T) { + if testCase.expectedPanic { + require.Panics(t, func() { + testCase.opt.Int() + }) + } else { + require.Equal(t, *testCase.opt.Value.(*int), testCase.opt.Int()) + } + }) + } +} + +func TestOption_Bool(t *testing.T) { + type testCase struct { + test string + opt *Option + expectedPanic bool + } + + testCases := []testCase{ + { + test: "valid with value bool type pointer", + opt: &Option{ + Name: "option", + Type: Bool, + Value: new(bool), + }, + }, + { + test: "invalid value bool type with no pointer", + opt: &Option{ + Name: "option", + Type: Bool, + Value: false, + }, + expectedPanic: true, + }, + { + test: "invalid value string type with no pointer", + opt: &Option{ + Name: "option", + Type: Bool, + Value: "string", + }, + expectedPanic: true, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.test, func(t *testing.T) { + if testCase.expectedPanic { + require.Panics(t, func() { + testCase.opt.Bool() + }) + } else { + require.Equal(t, *testCase.opt.Value.(*bool), testCase.opt.Bool()) + } + }) + } +} + +func TestArgument_Explain(t *testing.T) { + type testCase struct { + test string + arg *Argument + expectedResult string + } + + testCases := []testCase{ + { + test: "single arg", + arg: &Argument{ + Name: "arg", + }, + expectedResult: "", + }, + { + test: "multi arg", + arg: &Argument{ + Name: "arg", + Multiple: true, + }, + expectedResult: "...", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.test, func(t *testing.T) { + require.Equal(t, testCase.expectedResult, testCase.arg.Explain()) + }) + } +} + +func TestArgument_String(t *testing.T) { + type testCase struct { + test string + arg *Argument + expectedPanic bool + } + + testCases := []testCase{ + { + test: "valid with value string type pointer", + arg: &Argument{ + Name: "argument", + Type: String, + Value: new(string), + }, + }, + { + test: "invalid value string type with no pointer", + arg: &Argument{ + Name: "argument", + Type: String, + Value: "string", + }, + expectedPanic: true, + }, + { + test: "invalid value int type with no pointer", + arg: &Argument{ + Name: "argument", + Type: String, + Value: 0, + }, + expectedPanic: true, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.test, func(t *testing.T) { + if testCase.expectedPanic { + require.Panics(t, func() { + _ = testCase.arg.String() + }) + } else { + require.Equal(t, *testCase.arg.Value.(*string), testCase.arg.String()) + } + }) + } +} + +func TestArgument_Int(t *testing.T) { + type testCase struct { + test string + arg *Argument + expectedPanic bool + } + + testCases := []testCase{ + { + test: "valid with value int type pointer", + arg: &Argument{ + Name: "argument", + Type: Int, + Value: new(int), + }, + }, + { + test: "invalid value int type with no pointer", + arg: &Argument{ + Name: "argument", + Type: Int, + Value: 0, + }, + expectedPanic: true, + }, + { + test: "invalid value string type with no pointer", + arg: &Argument{ + Name: "argument", + Type: Int, + Value: "string", + }, + expectedPanic: true, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.test, func(t *testing.T) { + if testCase.expectedPanic { + require.Panics(t, func() { + testCase.arg.Int() + }) + } else { + require.Equal(t, *testCase.arg.Value.(*int), testCase.arg.Int()) + } + }) + } +} + +func TestArgument_Bool(t *testing.T) { + type testCase struct { + test string + arg *Argument + expectedPanic bool + } + + testCases := []testCase{ + { + test: "valid with value bool type pointer", + arg: &Argument{ + Name: "argument", + Type: Bool, + Value: new(bool), + }, + }, + { + test: "invalid value bool type with no pointer", + arg: &Argument{ + Name: "argument", + Type: Bool, + Value: false, + }, + expectedPanic: true, + }, + { + test: "invalid value string type with no pointer", + arg: &Argument{ + Name: "argument", + Type: Bool, + Value: "string", + }, + expectedPanic: true, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.test, func(t *testing.T) { + if testCase.expectedPanic { + require.Panics(t, func() { + testCase.arg.Bool() + }) + } else { + require.Equal(t, *testCase.arg.Value.(*bool), testCase.arg.Bool()) + } + }) + } +} diff --git a/output.go b/output.go index e66ba46..1765c37 100644 --- a/output.go +++ b/output.go @@ -5,24 +5,20 @@ import ( "strings" ) -type output struct { +type Output struct { text string err error } -func (o output) HasError() bool { - if o.err != nil { - return true - } - - return false +func (o Output) HasError() bool { + return o.err != nil } -func (o output) String() string { +func (o Output) String() string { return o.text } -func (o output) Int() int { +func (o Output) Int() int { i, err := strconv.Atoi(o.text) if err == nil { return i @@ -30,14 +26,10 @@ func (o output) Int() int { return 0 } -func (o output) Bool() bool { - if "true" == o.text { - return true - } - - return false +func (o Output) Bool() bool { + return "true" == o.text } -func (o output) Slice(sep string) []string { +func (o Output) Slice(sep string) []string { return strings.Split(o.text, sep) } diff --git a/output_test.go b/output_test.go new file mode 100644 index 0000000..94fd76f --- /dev/null +++ b/output_test.go @@ -0,0 +1,117 @@ +package exec + +import ( + "errors" + "github.com/stretchr/testify/require" + "testing" +) + +func TestOutput_HasError(t *testing.T) { + o := &Output{ + err: errors.New(""), + } + + require.True(t, o.HasError()) +} + +func TestOutput_String(t *testing.T) { + o := &Output{} + + require.True(t, o.String() == o.text) +} + +func TestOutput_Int(t *testing.T) { + type testCase struct { + test string + output *Output + expected int + } + + testCases := []testCase{ + { + test: "valid string to int via atoi", + output: &Output{ + text: "1", + }, + expected: 1, + }, + { + test: "invalid string to int via atoi", + output: &Output{ + text: "string", + }, + expected: 0, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.test, func(t *testing.T) { + require.Equal(t, testCase.expected, testCase.output.Int()) + }) + } +} + +func TestOutput_Bool(t *testing.T) { + type testCase struct { + test string + output *Output + expected bool + } + + testCases := []testCase{ + { + test: "valid true string to bool", + output: &Output{ + text: "true", + }, + expected: true, + }, + { + test: "valid !true string to bool", + output: &Output{ + text: "string", + }, + expected: false, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.test, func(t *testing.T) { + require.Equal(t, testCase.expected, testCase.output.Bool()) + }) + } +} + +func TestOutput_Slice(t *testing.T) { + type testCase struct { + test string + output *Output + separator string + expected []string + } + + testCases := []testCase{ + { + test: "valid splice a,b,c", + output: &Output{ + text: "a,b,c", + }, + separator: ",", + expected: []string{"a", "b", "c"}, + }, + { + test: "valid splice empty", + output: &Output{ + text: "", + }, + separator: ",", + expected: []string{""}, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.test, func(t *testing.T) { + require.Equal(t, testCase.expected, testCase.output.Slice(testCase.separator)) + }) + } +} diff --git a/recipes/deploy/cleanup.go b/recipes/deploy/cleanup.go index 01a966d..cc59a9c 100644 --- a/recipes/deploy/cleanup.go +++ b/recipes/deploy/cleanup.go @@ -6,6 +6,7 @@ import ( ) func init() { + exec := exec.Instance exec.Task("cleanup", func() { releases := exec.Get("releases_list").Slice() keep := exec.Get("keep_releases").Int() diff --git a/recipes/deploy/clear_paths.go b/recipes/deploy/clear_paths.go index d7d9a74..28c1081 100644 --- a/recipes/deploy/clear_paths.go +++ b/recipes/deploy/clear_paths.go @@ -6,6 +6,7 @@ import ( ) func init() { + exec := exec.Instance exec.Task("deploy:clear_paths", func() { paths := exec.Get("clear_paths").Slice() sudo := "" diff --git a/recipes/deploy/copy_dirs.go b/recipes/deploy/copy_dirs.go index 3e914db..9cfa245 100644 --- a/recipes/deploy/copy_dirs.go +++ b/recipes/deploy/copy_dirs.go @@ -6,6 +6,7 @@ import ( ) func init() { + exec := exec.Instance exec.Task("deploy:copy_dirs", func() { dirs := exec.Get("copy_dirs").Slice() diff --git a/recipes/deploy/defaults.go b/recipes/deploy/defaults.go index 66b1728..575bae6 100644 --- a/recipes/deploy/defaults.go +++ b/recipes/deploy/defaults.go @@ -2,7 +2,7 @@ package deploy import ( "fmt" - "github.com/go-exec/exec" + e "github.com/go-exec/exec" "regexp" "strconv" "strings" @@ -10,6 +10,7 @@ import ( ) func init() { + exec := e.Instance exec.Set("keep_releases", 5) exec.Set("repository", "") // Repository to deploy. @@ -67,15 +68,15 @@ func init() { }) branch := exec.NewOption("branch", "Branch to deploy") - branch.Type = exec.String + branch.Type = e.String exec.AddOption(branch) tag := exec.NewOption("tag", "Tag to deploy") - tag.Type = exec.String + tag.Type = e.String exec.AddOption(tag) revision := exec.NewOption("revision", "Revision to deploy") - revision.Type = exec.String + revision.Type = e.String exec.AddOption(revision) exec.Task("current", func() { @@ -101,7 +102,7 @@ func init() { }).Once().Private() exec.Task("onEnd", func() { - exec.Println(fmt.Sprintf("Finished in %s!", time.Now().Sub(exec.Get("startTime").Time()).String())) + exec.Println(fmt.Sprintf("Finished in %s!", time.Since(exec.Get("startTime").Time()).String())) exec.Println("End") }).Once().Private() } diff --git a/recipes/deploy/lock.go b/recipes/deploy/lock.go index 112bc4f..a3d53d8 100644 --- a/recipes/deploy/lock.go +++ b/recipes/deploy/lock.go @@ -3,6 +3,7 @@ package deploy import "github.com/go-exec/exec" func init() { + exec := exec.Instance exec.Task("deploy:lock", func() { locked := exec.Remote("if [ -f {{deploy_path}}/.dep/deploy.lock ]; then echo 'true'; fi").Bool() diff --git a/recipes/deploy/prepare.go b/recipes/deploy/prepare.go index 71bd41d..74d6a97 100644 --- a/recipes/deploy/prepare.go +++ b/recipes/deploy/prepare.go @@ -3,6 +3,7 @@ package deploy import "github.com/go-exec/exec" func init() { + exec := exec.Instance exec.Task("deploy:prepare", func() { exec.Remote("if [ ! -d {{deploy_path}} ]; then mkdir -p {{deploy_path}}; fi") diff --git a/recipes/deploy/release.go b/recipes/deploy/release.go index ca02d98..c7f9a73 100644 --- a/recipes/deploy/release.go +++ b/recipes/deploy/release.go @@ -9,6 +9,7 @@ import ( ) func init() { + exec := exec.Instance exec.Set("keep_releases", -1) exec.Set("release_name", func() interface{} { @@ -96,7 +97,7 @@ func init() { // Metainfo. // Save metainfo about release. - exec.Remote("echo `date +\"%Y%m%d%H%M%S\"`,{{release_name}} >> .dep/releases") + exec.Remote("echo `%s`,{{release_name}} >> .dep/releases", "date +\"%Y%m%d%H%M%S\"") // Make new release. exec.Remote(fmt.Sprintf("mkdir %s", releasePath)) diff --git a/recipes/deploy/rollback.go b/recipes/deploy/rollback.go index 6077d51..94b520d 100644 --- a/recipes/deploy/rollback.go +++ b/recipes/deploy/rollback.go @@ -6,6 +6,7 @@ import ( ) func init() { + exec := exec.Instance exec.Task("rollback", func() { releases := exec.Get("releases_list").Slice() @@ -19,5 +20,4 @@ func init() { exec.Println(fmt.Sprintf("Rollback to `%s` release was successful.", releases[1])) } }).ShortDescription("Rollback to previous release") - } diff --git a/recipes/deploy/shared.go b/recipes/deploy/shared.go index 4edac0f..de2e4e3 100644 --- a/recipes/deploy/shared.go +++ b/recipes/deploy/shared.go @@ -8,6 +8,7 @@ import ( ) func init() { + exec := exec.Instance exec.Task("deploy:shared", func() { sharedPath := "{{deploy_path}}/shared" diff --git a/recipes/deploy/stage.go b/recipes/deploy/stage.go index 293b96d..23cbf93 100644 --- a/recipes/deploy/stage.go +++ b/recipes/deploy/stage.go @@ -3,6 +3,7 @@ package deploy import "github.com/go-exec/exec" func init() { + exec := exec.Instance stage := exec.NewArgument("stage", "Provide the running stage") stage.Default = "qa" diff --git a/recipes/deploy/symlink.go b/recipes/deploy/symlink.go index da0af48..ec2ace7 100644 --- a/recipes/deploy/symlink.go +++ b/recipes/deploy/symlink.go @@ -3,6 +3,7 @@ package deploy import "github.com/go-exec/exec" func init() { + exec := exec.Instance exec.Task("deploy:symlink", func() { if exec.Remote("if [[ \"$(man mv 2>/dev/null)\" =~ '--no-target-directory' ]]; then echo 'true'; fi").Bool() { exec.Remote("mv -T {{deploy_path}}/release {{deploy_path}}/current") diff --git a/recipes/deploy/update_code.go b/recipes/deploy/update_code.go index e066e94..c6e8446 100644 --- a/recipes/deploy/update_code.go +++ b/recipes/deploy/update_code.go @@ -6,6 +6,7 @@ import ( ) func init() { + exec := exec.Instance exec.Task("deploy:update_code", func() { repository := exec.Get("repository").String() branch := exec.Get("branch").String() diff --git a/recipes/deploy/writable.go b/recipes/deploy/writable.go index 7e67657..c06ac89 100644 --- a/recipes/deploy/writable.go +++ b/recipes/deploy/writable.go @@ -8,6 +8,7 @@ import ( ) func init() { + exec := exec.Instance exec.Task("deploy:writable", func() { dirs := strings.Join(exec.Get("writable_dirs").Slice(), " ") mode := exec.Get("writable_mode").String() diff --git a/recipes/php/defaults.go b/recipes/php/defaults.go index 67b9da9..3608be8 100644 --- a/recipes/php/defaults.go +++ b/recipes/php/defaults.go @@ -6,6 +6,7 @@ import ( ) func init() { + exec := exec.Instance exec.Set("http_user", false) exec.Set("http_group", false) diff --git a/recipes/php/vendors.go b/recipes/php/vendors.go index f214ac3..5fbc5f1 100644 --- a/recipes/php/vendors.go +++ b/recipes/php/vendors.go @@ -3,6 +3,7 @@ package php import "github.com/go-exec/exec" func init() { + exec := exec.Instance exec.Task("deploy:vendors", func() { exec.Remote("cd {{release_path}} && {{env_vars}} {{bin/composer}} {{composer_options}}") }).ShortDescription("Installing vendors") diff --git a/recipes/symfony/recipe.go b/recipes/symfony/recipe.go index bb70396..f542177 100644 --- a/recipes/symfony/recipe.go +++ b/recipes/symfony/recipe.go @@ -6,6 +6,7 @@ import ( ) func init() { + exec := exec.Instance exec. TaskGroup( "deploy", diff --git a/server.go b/server.go index a967aec..39fb8b5 100644 --- a/server.go +++ b/server.go @@ -6,8 +6,8 @@ type server struct { Name string Dsn string Configs map[string]*config - key *string + key *string roles []string sshClient *sshClient } @@ -38,7 +38,7 @@ func (s *server) Key(file string) *server { } func (s *server) GetUser() string { - return s.Dsn[:strings.Index(s.Dsn, "@")-1] + return s.Dsn[:strings.Index(s.Dsn, "@")] } func (s *server) GetHost() string { diff --git a/server_test.go b/server_test.go new file mode 100644 index 0000000..504a511 --- /dev/null +++ b/server_test.go @@ -0,0 +1,69 @@ +package exec + +import ( + "github.com/stretchr/testify/require" + "testing" +) + +func TestServer_AddRole(t *testing.T) { + s := &server{ + Name: "qa", + Dsn: "root@domain.com", + } + s.AddRole("qa").AddRole("test") + + require.Equal(t, s.roles, []string{"qa", "test"}) +} + +func TestServer_HasRole(t *testing.T) { + s := &server{ + Name: "qa", + Dsn: "root@domain.com", + } + s.AddRole("qa") + + require.True(t, s.HasRole("qa")) + require.False(t, s.HasRole("test")) +} + +func TestServer_Set(t *testing.T) { + s := &server{ + Name: "qa", + Dsn: "root@domain.com", + Configs: make(map[string]*config), + } + s.Set("config", "value") + + require.Contains(t, s.Configs, "config") + require.Equal(t, s.Configs["config"].Value(), "value") +} + +func TestServer_Key(t *testing.T) { + s := &server{ + Name: "qa", + Dsn: "root@domain.com", + Configs: make(map[string]*config), + sshClient: &sshClient{}, + } + s.Key("key") + + require.Contains(t, s.sshClient.keys, "key") +} + +func TestServer_GetUser(t *testing.T) { + s := &server{ + Name: "qa", + Dsn: "root@domain.com", + } + + require.Equal(t, s.GetUser(), "root") +} + +func TestServer_GetHost(t *testing.T) { + s := &server{ + Name: "qa", + Dsn: "root@domain.com", + } + + require.Equal(t, s.GetHost(), "domain.com") +} diff --git a/ssh.go b/ssh.go index 71b09f4..c231c03 100644 --- a/ssh.go +++ b/ssh.go @@ -27,7 +27,6 @@ type sshClient struct { sessOpened bool running bool env string //export FOO="bar"; export BAR="baz"; - color string keys []string authMethod ssh.AuthMethod initAuthMethodOnce sync.Once @@ -66,12 +65,12 @@ func (c *sshClient) parseHost(host string) error { c.user = u.Username } - if strings.Index(c.host, "/") != -1 { + if strings.Contains(c.host, "/") { return errConnect{c.user, c.host, "unexpected slash in the host URL"} } // Add default port, if not set - if strings.Index(c.host, ":") == -1 { + if !strings.Contains(c.host, ":") { c.host += ":22" } @@ -113,6 +112,14 @@ func (c *sshClient) initAuthMethod() { // SSHDialFunc can dial an ssh server and return a client type sshDialFunc func(net, addr string, config *ssh.ClientConfig) (*ssh.Client, error) +// WithConnection associate an existing connection +func (c *sshClient) WithConnection(conn *ssh.Client) { + if conn != nil { + c.conn = conn + c.connOpened = true + } +} + // Connect creates SSH connection to a specified host. // It expects the host of the form "[ssh://]host[:port]". func (c *sshClient) Connect(host string) error { @@ -281,7 +288,7 @@ func (c *sshClient) Signal(sig os.Signal) error { // which sounds like something that should be fixed/resolved // upstream in the golang.org/x/crypto/ssh pkg. // https://github.com/golang/go/issues/4115#issuecomment-66070418 - c.remoteStdin.Write([]byte("\x03")) + _, _ = c.remoteStdin.Write([]byte("\x03")) return c.sess.Signal(ssh.SIGINT) default: return fmt.Errorf("%v not supported", sig) diff --git a/ssh_mock/ssh_mock.go b/ssh_mock/ssh_mock.go new file mode 100644 index 0000000..6fb83a1 --- /dev/null +++ b/ssh_mock/ssh_mock.go @@ -0,0 +1,272 @@ +// Modified to support mocking for go-exec + +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin dragonfly freebsd linux netbsd openbsd plan9 + +package ssh_mock + +// functional test harness for unix. + +import ( + "bytes" + "fmt" + "io/ioutil" + "log" + "net" + "os" + "os/exec" + "os/user" + "path/filepath" + "testing" + "text/template" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/testdata" +) + +const sshd_config = ` +Protocol 2 +HostKey {{.Dir}}/id_rsa +HostKey {{.Dir}}/id_dsa +HostKey {{.Dir}}/id_ecdsa +Pidfile {{.Dir}}/sshd.pid +#UsePrivilegeSeparation no +KeyRegenerationInterval 3600 +ServerKeyBits 768 +SyslogFacility AUTH +#LogLevel DEBUG2 +LoginGraceTime 120 +PermitRootLogin no +StrictModes no +RSAAuthentication yes +PubkeyAuthentication yes +AuthorizedKeysFile {{.Dir}}/authorized_keys +TrustedUserCAKeys {{.Dir}}/id_ecdsa.pub +IgnoreRhosts yes +RhostsRSAAuthentication no +HostbasedAuthentication no +PubkeyAcceptedKeyTypes=* +` + +var configTmpl = template.Must(template.New("").Parse(sshd_config)) + +type Server struct { + t *testing.T + cleanup func() // executed during Shutdown + configfile string + cmd *exec.Cmd + output bytes.Buffer // holds stderr from sshd process + + // Client half of the network connection. + clientConn net.Conn +} + +func username() string { + var username string + if user, err := user.Current(); err == nil { + username = user.Username + } else { + // user.Current() currently requires cgo. If an error is + // returned attempt to get the username from the environment. + log.Printf("user.Current: %v; falling back on $USER", err) + username = os.Getenv("USER") + } + if username == "" { + panic("Unable to get username") + } + return username +} + +type storedHostKey struct { + // keys map from an algorithm string to binary key data. + keys map[string][]byte + + // checkCount counts the Check calls. Used for testing + // rekeying. + checkCount int +} + +func (k *storedHostKey) Add(key ssh.PublicKey) { + if k.keys == nil { + k.keys = map[string][]byte{} + } + k.keys[key.Type()] = key.Marshal() +} + +func (k *storedHostKey) Check(addr string, remote net.Addr, key ssh.PublicKey) error { + k.checkCount++ + algo := key.Type() + + if k.keys == nil || !bytes.Equal(key.Marshal(), k.keys[algo]) { + return fmt.Errorf("host key mismatch. Got %q, want %q", key, k.keys[algo]) + } + return nil +} + +func hostKeyDB() *storedHostKey { + keyChecker := &storedHostKey{} + keyChecker.Add(testPublicKeys["ecdsa"]) + keyChecker.Add(testPublicKeys["rsa"]) + keyChecker.Add(testPublicKeys["dsa"]) + return keyChecker +} + +func ClientConfig() *ssh.ClientConfig { + config := &ssh.ClientConfig{ + User: username(), + Auth: []ssh.AuthMethod{ + ssh.PublicKeys(testSigners["user"]), + }, + HostKeyCallback: hostKeyDB().Check, + } + return config +} + +// unixConnection creates two halves of a connected net.UnixConn. It +// is used for connecting the Go SSH client with sshd without opening +// ports. +func unixConnection() (*net.UnixConn, *net.UnixConn, error) { + dir, err := ioutil.TempDir("", "unixConnection") + if err != nil { + return nil, nil, err + } + defer os.Remove(dir) + + addr := filepath.Join(dir, "ssh") + listener, err := net.Listen("unix", addr) + if err != nil { + return nil, nil, err + } + defer listener.Close() + c1, err := net.Dial("unix", addr) + if err != nil { + return nil, nil, err + } + + c2, err := listener.Accept() + if err != nil { + c1.Close() + return nil, nil, err + } + + return c1.(*net.UnixConn), c2.(*net.UnixConn), nil +} + +func (s *Server) TryDial(config *ssh.ClientConfig) (*ssh.Client, error) { + sshd, err := exec.LookPath("sshd") + if err != nil { + s.t.Skipf("skipping test: %v", err) + } + + c1, c2, err := unixConnection() + if err != nil { + s.t.Fatalf("unixConnection: %v", err) + } + + s.cmd = exec.Command(sshd, "-f", s.configfile, "-i", "-e") + f, err := c2.File() + if err != nil { + s.t.Fatalf("UnixConn.File: %v", err) + } + defer f.Close() + s.cmd.Stdin = f + s.cmd.Stdout = f + s.cmd.Stderr = &s.output + if err := s.cmd.Start(); err != nil { + s.t.Fail() + s.Shutdown() + s.t.Fatalf("s.cmd.Start: %v", err) + } + s.clientConn = c1 + conn, chans, reqs, err := ssh.NewClientConn(c1, "", config) + if err != nil { + return nil, err + } + return ssh.NewClient(conn, chans, reqs), nil +} + +func (s *Server) Dial(config *ssh.ClientConfig) *ssh.Client { + conn, err := s.TryDial(config) + if err != nil { + s.t.Fail() + s.Shutdown() + s.t.Fatalf("ssh.Client: %v", err) + } + return conn +} + +func (s *Server) Shutdown() { + if s.cmd != nil && s.cmd.Process != nil { + // Don't check for errors; if it fails it's most + // likely "os: process already finished", and we don't + // care about that. Use os.Interrupt, so child + // processes are killed too. + _ = s.cmd.Process.Signal(os.Interrupt) + _ = s.cmd.Wait() + } + if s.t.Failed() { + // log any output from sshd process + s.t.Logf("sshd: %s", s.output.String()) + } + s.cleanup() +} + +func writeFile(path string, contents []byte) { + f, err := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0600) + if err != nil { + panic(err) + } + defer func() { + _ = f.Close() + }() + if _, err := f.Write(contents); err != nil { + panic(err) + } +} + +// NewServer returns a new mock ssh server. +func NewServer(t *testing.T) *Server { + if testing.Short() { + t.Skip("skipping test due to -short") + } + dir, err := ioutil.TempDir("", "sshtest") + if err != nil { + t.Fatal(err) + } + f, err := os.Create(filepath.Join(dir, "sshd_config")) + if err != nil { + t.Fatal(err) + } + err = configTmpl.Execute(f, map[string]string{ + "Dir": dir, + }) + if err != nil { + t.Fatal(err) + } + _ = f.Close() + + for k, v := range testdata.PEMBytes { + filename := "id_" + k + writeFile(filepath.Join(dir, filename), v) + writeFile(filepath.Join(dir, filename+".pub"), ssh.MarshalAuthorizedKey(testPublicKeys[k])) + } + + var authkeys bytes.Buffer + for k := range testdata.PEMBytes { + authkeys.Write(ssh.MarshalAuthorizedKey(testPublicKeys[k])) + } + writeFile(filepath.Join(dir, "authorized_keys"), authkeys.Bytes()) + + return &Server{ + t: t, + configfile: f.Name(), + cleanup: func() { + if err := os.RemoveAll(dir); err != nil { + t.Error(err) + } + }, + } +} diff --git a/ssh_mock/testdata.go b/ssh_mock/testdata.go new file mode 100644 index 0000000..54239be --- /dev/null +++ b/ssh_mock/testdata.go @@ -0,0 +1,66 @@ +// Modified to support mocking for go-exec + +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// IMPLEMENTATION NOTE: To avoid a package loop, this file is in three places: +// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three +// instances. + +package ssh_mock + +import ( + "crypto/rand" + "fmt" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/testdata" +) + +var ( + testPrivateKeys map[string]interface{} + testSigners map[string]ssh.Signer + testPublicKeys map[string]ssh.PublicKey +) + +func init() { + var err error + + n := len(testdata.PEMBytes) + testPrivateKeys = make(map[string]interface{}, n) + testSigners = make(map[string]ssh.Signer, n) + testPublicKeys = make(map[string]ssh.PublicKey, n) + for t, k := range testdata.PEMBytes { + testPrivateKeys[t], err = ssh.ParseRawPrivateKey(k) + if err != nil { + panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err)) + } + testSigners[t], err = ssh.NewSignerFromKey(testPrivateKeys[t]) + if err != nil { + panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err)) + } + testPublicKeys[t] = testSigners[t].PublicKey() + } + + // Create a cert and sign it for use in tests. + testCert := &ssh.Certificate{ + Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil + ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage + ValidAfter: 0, // unix epoch + ValidBefore: ssh.CertTimeInfinity, // The end of currently representable time. + Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil + Key: testPublicKeys["ecdsa"], + SignatureKey: testPublicKeys["rsa"], + Permissions: ssh.Permissions{ + CriticalOptions: map[string]string{}, + Extensions: map[string]string{}, + }, + } + _ = testCert.SignCert(rand.Reader, testSigners["rsa"]) + testPrivateKeys["cert"] = testPrivateKeys["ecdsa"] + testSigners["cert"], err = ssh.NewCertSigner(testCert, testSigners["ecdsa"]) + if err != nil { + panic(fmt.Sprintf("Unable to create certificate signer: %v", err)) + } +} diff --git a/task.go b/task.go index 183d2f4..a85ab22 100644 --- a/task.go +++ b/task.go @@ -19,6 +19,7 @@ type task struct { Options map[string]*Option Arguments map[string]*Argument + exec *Exec run taskFunction subtasks map[string]*task shortDescription string @@ -121,7 +122,7 @@ func (t *task) getOrderedArguments() sortArguments { // printhelp prints the return value of help to the standard output. func (t *task) printhelp(taskName string) { - fmt.Printf(t.help(taskName)) + fmt.Print(t.help(taskName)) } // usageString returns a short string containing the syntax of this command. @@ -216,7 +217,7 @@ func (t *task) execute(taskName string, cmdArgs []string) error { } // Executing the onStart task - onStart() + t.exec.onStart() if len(t.before) > 0 { for _, tb := range t.before { @@ -234,7 +235,7 @@ func (t *task) execute(taskName string, cmdArgs []string) error { } // Executing the onEnd task - onEnd() + t.exec.onEnd() // Execute it only once if requested if t.once && !t.executedOnce { diff --git a/utils.go b/utils.go index 2767f9d..c616357 100644 --- a/utils.go +++ b/utils.go @@ -16,27 +16,27 @@ import ( ) // Cd is a remote helper function that runs a `cd` before a command -func Cd(path string) { - command := "cd " + Parse(path) - color.Green("[%s] %s %s", ServerContext.Name, color.GreenString(">"), command) - ServerContext.sshClient.env = command + "; " +func (e *Exec) Cd(path string) { + command := "cd " + e.Parse(path) + color.Green("[%s] %s %s", e.ServerContext.Name, color.GreenString(">"), command) + e.ServerContext.sshClient.env = command + "; " } // CommandExist checks if a remote command exists on server -func CommandExist(command string) bool { - return Remote("if hash %s 2>/dev/null; then echo 'true'; fi", command).Bool() +func (e *Exec) CommandExist(command string) bool { + return e.Remote("if hash %s 2>/dev/null; then echo 'true'; fi", command).Bool() } // Parse parses {{var}} with Get(var) -func Parse(text string) string { +func (e *Exec) Parse(text string) string { re := regexp.MustCompile(`\{\{\s*([\w\.\/]+)\s*\}\}`) if !re.MatchString(text) { return text } return re.ReplaceAllStringFunc(text, func(str string) string { - name := strings.TrimRight(strings.TrimLeft(str, "{{"), "}}") - if Has(name) { - return Parse(Get(name).String()) + name := strings.TrimSuffix(strings.TrimPrefix(str, "{{"), "}}") + if e.Has(name) { + return e.Parse(e.Get(name).String()) } return str }) @@ -44,39 +44,39 @@ func Parse(text string) string { // RunIfNoBinary runs a remote command if a binary is not found // command can be an array of string commands or one a string command -func RunIfNoBinary(binary string, command interface{}) (o output) { - return Remote("if [ ! -e \"`which %s`\" ]; then %s; fi", binary, commandToString(command)) +func (e *Exec) RunIfNoBinary(binary string, command interface{}) (o Output) { + return e.Remote("if [ ! -e \"`which %s`\" ]; then %s; fi", binary, commandToString(command)) } // RunIfNoBinaries runs multiple RunIfNoBinary -func RunIfNoBinaries(config map[string]interface{}) { +func (e *Exec) RunIfNoBinaries(config map[string]interface{}) { for binary, command := range config { - RunIfNoBinary(binary, command) + e.RunIfNoBinary(binary, command) } } // RunIf runs a remote command if condition is true // command can be an array of string commands or one a string command -func RunIf(condition string, command interface{}) (o output) { - return Remote("if %s; then %s; fi", condition, commandToString(command)) +func (e *Exec) RunIf(condition string, command interface{}) (o Output) { + return e.Remote("if %s; then %s; fi", condition, commandToString(command)) } // RunIfs runs multiple RunIf -func RunIfs(config map[string]interface{}) { +func (e *Exec) RunIfs(config map[string]interface{}) { for condition, command := range config { - RunIf(condition, command) + e.RunIf(condition, command) } } // UploadFileSudo uploads a local file to a remote file with sudo -func UploadFileSudo(source, destination string) { +func (e *Exec) UploadFileSudo(source, destination string) { tempFile := "/tmp/" + uuid.NewV4().String() - Upload(source, tempFile) - Remote("sudo mv %s %s", tempFile, destination) + e.Upload(source, tempFile) + e.Remote("sudo mv %s %s", tempFile, destination) } // UploadTemplateFileSudo parses a local template file with context, and uploads it to a remote file with sudo -func UploadTemplateFileSudo(source, destination string, context interface{}) { +func (e *Exec) UploadTemplateFileSudo(source, destination string, context interface{}) { tempFile := "/tmp/" + uuid.NewV4().String() t, err := template.New(path.Base(source)).ParseFiles(source) @@ -91,26 +91,26 @@ func UploadTemplateFileSudo(source, destination string, context interface{}) { if err := ioutil.WriteFile(tempFile, tpl.Bytes(), os.FileMode(0644)); err != nil { color.Red("[%s] %s %s", "local", "<", err) } else { - Upload(tempFile, tempFile) - Local("rm %s", tempFile) - Remote("sudo mv %s %s", tempFile, destination) + e.Upload(tempFile, tempFile) + e.Local("rm %s", tempFile) + e.Remote("sudo mv %s %s", tempFile, destination) } } // UploadTemplateStringSudo uploads a string content to a remote file with sudo -func UploadTemplateStringSudo(content, destination string) { +func (e *Exec) UploadTemplateStringSudo(content, destination string) { tempFile := "/tmp/" + uuid.NewV4().String() if err := ioutil.WriteFile(tempFile, []byte(content), os.FileMode(0644)); err != nil { color.Red("[%s] %s %s", "local", "<", err) } else { - Upload(tempFile, tempFile) - Local("rm %s", tempFile) - Remote("sudo mv %s %s", tempFile, destination) + e.Upload(tempFile, tempFile) + e.Local("rm %s", tempFile) + e.Remote("sudo mv %s %s", tempFile, destination) } } // LocalTemplateFile parses a local template file with context, and moves it to a destination -func LocalTemplateFile(source, destination string, context interface{}) { +func (e *Exec) LocalTemplateFile(source, destination string, context interface{}) { tempFile := "/tmp/" + uuid.NewV4().String() t, err := template.New(path.Base(source)).ParseFiles(source) @@ -125,12 +125,12 @@ func LocalTemplateFile(source, destination string, context interface{}) { if err := ioutil.WriteFile(tempFile, tpl.Bytes(), os.FileMode(0644)); err != nil { color.Red("[%s] %s %s", "local", "<", err) } else { - Local("mv %s %s", tempFile, destination) + e.Local("mv %s %s", tempFile, destination) } } // CompileLocalTemplateFile parses a local source file template with context and returns it -func CompileLocalTemplateFile(source string, context interface{}) string { +func (e *Exec) CompileLocalTemplateFile(source string, context interface{}) string { t, err := template.New(path.Base(source)).ParseFiles(source) if err != nil { color.Red("[%s] %s %s", "local", "<", err) @@ -143,7 +143,7 @@ func CompileLocalTemplateFile(source string, context interface{}) string { } // CompileLocalTemplateString parses a local source string template with context and returns it -func CompileLocalTemplateString(source string, context interface{}) string { +func (e *Exec) CompileLocalTemplateString(source string, context interface{}) string { t, err := template.New(uuid.NewV4().String()).Parse(source) if err != nil { color.Red("[%s] %s %s", "local", "<", err) @@ -156,71 +156,71 @@ func CompileLocalTemplateString(source string, context interface{}) string { } // ReplaceInRemoteFile replaces a search string with a replace string, in a remote file -func ReplaceInRemoteFile(file, search, replace string) { +func (e *Exec) ReplaceInRemoteFile(file, search, replace string) { tempFile := "/tmp/" + uuid.NewV4().String() - Remote("sudo cp %s %s ; sudo chown %s %s", file, tempFile, ServerContext.GetUser(), tempFile) - Download(tempFile, tempFile) + e.Remote("sudo cp %s %s ; sudo chown %s %s", file, tempFile, e.ServerContext.GetUser(), tempFile) + e.Download(tempFile, tempFile) if tempFileContent, err := ioutil.ReadFile(tempFile); err != nil { color.Red("[%s] %s %s", "local", "<", err) } else { - tempFileContent := strings.Replace(string(tempFileContent), search, Parse(replace), -1) + tempFileContent := strings.Replace(string(tempFileContent), search, e.Parse(replace), -1) if err := ioutil.WriteFile(tempFile, []byte(tempFileContent), os.FileMode(0644)); err != nil { color.Red("[%s] %s %s", "local", "<", err) } else { - UploadFileSudo(tempFile, file) - Remote("sudo rm -rf %s", tempFile) - Local("rm -rf %s", tempFile) + e.UploadFileSudo(tempFile, file) + e.Remote("sudo rm -rf %s", tempFile) + e.Local("rm -rf %s", tempFile) } } } // AddInRemoteFile appends a text string to a remote file -func AddInRemoteFile(text, file string) { +func (e *Exec) AddInRemoteFile(text, file string) { tempFile := "/tmp/" + uuid.NewV4().String() - Remote("sudo cp %s %s ; sudo chown %s %s", file, tempFile, ServerContext.GetUser(), tempFile) - Download(tempFile, tempFile) + e.Remote("sudo cp %s %s ; sudo chown %s %s", file, tempFile, e.ServerContext.GetUser(), tempFile) + e.Download(tempFile, tempFile) if tempFileContent, err := ioutil.ReadFile(tempFile); err != nil { color.Red("[%s] %s %s", "local", "<", err) } else { - tempFileContent := string(tempFileContent) + Parse(text) + tempFileContent := string(tempFileContent) + e.Parse(text) if err := ioutil.WriteFile(tempFile, []byte(tempFileContent), os.FileMode(0644)); err != nil { color.Red("[%s] %s %s", "local", "<", err) } else { - UploadFileSudo(tempFile, file) - Remote("sudo rm -rf %s", tempFile) - Local("rm -rf %s", tempFile) + e.UploadFileSudo(tempFile, file) + e.Remote("sudo rm -rf %s", tempFile) + e.Local("rm -rf %s", tempFile) } } } // RemoveFromRemoteFile cuts out a text string from remote file -func RemoveFromRemoteFile(text, file string) { +func (e *Exec) RemoveFromRemoteFile(text, file string) { tempFile := "/tmp/" + uuid.NewV4().String() - Remote("sudo cp %s %s ; sudo chown %s %s", file, tempFile, ServerContext.GetUser(), tempFile) - Download(tempFile, tempFile) + e.Remote("sudo cp %s %s ; sudo chown %s %s", file, tempFile, e.ServerContext.GetUser(), tempFile) + e.Download(tempFile, tempFile) if tempFileContent, err := ioutil.ReadFile(tempFile); err != nil { color.Red("[%s] %s %s", "local", "<", err) } else { - tempFileContent := strings.Replace(string(tempFileContent), Parse(text), "", -1) + tempFileContent := strings.Replace(string(tempFileContent), e.Parse(text), "", -1) if err := ioutil.WriteFile(tempFile, []byte(tempFileContent), os.FileMode(0644)); err != nil { color.Red("[%s] %s %s", "local", "<", err) } else { - UploadFileSudo(tempFile, file) - Remote("sudo rm -rf %s", tempFile) - Local("rm -rf %s", tempFile) + e.UploadFileSudo(tempFile, file) + e.Remote("sudo rm -rf %s", tempFile) + e.Local("rm -rf %s", tempFile) } } } // IsInRemoteFile return true if text is found in a remote file -func IsInRemoteFile(text, file string) bool { +func (e *Exec) IsInRemoteFile(text, file string) bool { text = strings.Trim(text, " ") - return Remote("if [ \"`sudo cat %s | grep '%s'`\" ]; then echo 'true'; fi", file, text).Bool() + return e.Remote("if [ \"`sudo cat %s | grep '%s'`\" ]; then echo 'true'; fi", file, text).Bool() } // Ask asks a question and waits for an answer // first item from attributes is set as default value, which is optional -func Ask(question string, attributes ...string) string { +func (e *Exec) Ask(question string, attributes ...string) string { scanner := bufio.NewScanner(os.Stdin) var defaultResponse string @@ -243,7 +243,7 @@ func Ask(question string, attributes ...string) string { // AskWithConfirmation asks a confirmation question and waits for an y/n answer // first item from attributes is set as default value, which is optional -func AskWithConfirmation(question string, attributes ...bool) bool { +func (e *Exec) AskWithConfirmation(question string, attributes ...bool) bool { scanner := bufio.NewScanner(os.Stdin) var defaultResponse bool @@ -290,7 +290,7 @@ First item from attributes must be a map with default and choices keys and strin } ``` */ -func AskWithChoices(question string, attributes ...map[string]interface{}) (responses []string) { +func (e *Exec) AskWithChoices(question string, attributes ...map[string]interface{}) (responses []string) { scanner := bufio.NewScanner(os.Stdin) var ( @@ -359,10 +359,8 @@ func commandToString(run interface{}) string { switch rt.Kind() { case reflect.Slice: runS = strings.Join(run.([]string), " ; ") - break case reflect.String: runS = run.(string) - break } return runS }