diff --git a/cmd/stack.go b/cmd/stack.go index c082606754..f43d3a4dbc 100644 --- a/cmd/stack.go +++ b/cmd/stack.go @@ -6,17 +6,12 @@ package cmd import ( "fmt" - "os" - "path/filepath" "strings" "github.com/jedib0t/go-pretty/table" "github.com/pkg/errors" "github.com/spf13/cobra" - "github.com/elastic/go-sysinfo" - "github.com/elastic/go-sysinfo/types" - "github.com/elastic/elastic-package/internal/cobraext" "github.com/elastic/elastic-package/internal/common" "github.com/elastic/elastic-package/internal/install" @@ -209,11 +204,13 @@ func setupStackCommand() *cobraext.Command { } if shellName == cobraext.ShellInitShellDetect { - shellName, err = detectShell() + shellName, err = stack.AutodetectedShell() if err != nil { return fmt.Errorf("cannot detect parent shell from current process: %w", err) } fmt.Fprintf(cmd.OutOrStderr(), "Detected shell: %s\n", shellName) + } else { + stack.SelectShell(shellName) } profile, err := profile.LoadProfile(profileName) @@ -221,7 +218,7 @@ func setupStackCommand() *cobraext.Command { return errors.Wrap(err, "error loading profile") } - shellCode, err := stack.ShellInit(profile, shellName) + shellCode, err := stack.ShellInit(profile) if err != nil { return errors.Wrap(err, "shellinit failed") } @@ -352,43 +349,3 @@ func printStatus(cmd *cobra.Command, servicesStatus []stack.ServiceStatus) { t.SetStyle(table.StyleRounded) cmd.Println(t.Render()) } - -func getParentInfo(ppid int) (types.ProcessInfo, error) { - parent, err := sysinfo.Process(ppid) - if err != nil { - return types.ProcessInfo{}, fmt.Errorf("cannot retrieve information for process %d: %w", ppid, err) - } - - parentInfo, err := parent.Info() - if err != nil { - return types.ProcessInfo{}, fmt.Errorf("cannot retrieve information for parent of process %d: %w", ppid, err) - } - - return parentInfo, nil -} - -func getShellName(exe string) string { - shell := filepath.Base(exe) - // NOTE: remove .exe extension from executable names present in Windows - shell = strings.TrimSuffix(shell, ".exe") - return shell -} - -func detectShell() (string, error) { - ppid := os.Getppid() - parentInfo, err := getParentInfo(ppid) - if err != nil { - return "", err - } - - shell := getShellName(parentInfo.Exe) - if shell == "go" { - parentParentInfo, err := getParentInfo(parentInfo.PPID) - if err != nil { - return "", fmt.Errorf("cannot retrieve parent parent info: %w", err) - } - return getShellName(parentParentInfo.Exe), nil - } - - return shell, nil -} diff --git a/cmd/stack_test.go b/cmd/stack_test.go index 133c8a9ec6..af764b8eea 100644 --- a/cmd/stack_test.go +++ b/cmd/stack_test.go @@ -6,12 +6,8 @@ package cmd import ( "fmt" - "os" - "reflect" "testing" - "github.com/elastic/go-sysinfo" - "github.com/elastic/go-sysinfo/types" "github.com/stretchr/testify/require" ) @@ -51,63 +47,3 @@ func TestValidateServicesFlag(t *testing.T) { } } - -func Test_getShellName(t *testing.T) { - type args struct { - exe string - } - tests := []struct { - name string - args args - want string - }{ - {"linux exec", args{exe: "bash"}, "bash"}, - {"windows exec", args{exe: "cmd.exe"}, "cmd"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := getShellName(tt.args.exe); got != tt.want { - t.Errorf("getShellName() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_getParentInfo(t *testing.T) { - ppid := os.Getppid() - parent, err := sysinfo.Process(ppid) - if err != nil { - panic(err) - } - info, err := parent.Info() - if err != nil { - panic(err) - } - - type args struct { - ppid int - } - tests := []struct { - name string - args args - want types.ProcessInfo - wantErr bool - }{ - // TODO: Add test cases. - {"test parent", args{ppid}, info, false}, - {"bogus ppid", args{999999}, types.ProcessInfo{}, true}, - {"no parent", args{1}, types.ProcessInfo{}, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := getParentInfo(tt.args.ppid) - if (err != nil) != tt.wantErr { - t.Errorf("getParentInfo() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("getParentInfo() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/internal/stack/errors.go b/internal/stack/errors.go index fa138f709f..a7840c79a6 100644 --- a/internal/stack/errors.go +++ b/internal/stack/errors.go @@ -9,5 +9,5 @@ import "fmt" // UndefinedEnvError formats an error reported for undefined variable. func UndefinedEnvError(envName string) error { return fmt.Errorf("undefined environment variable: %s. If you have started the Elastic stack using the elastic-package tool, "+ - `please load stack environment variables using 'eval "$(elastic-package stack shellinit)"' or set their values manually`, envName) + `please load stack environment variables using '%s' or set their values manually`, envName, helpText(shellType)) } diff --git a/internal/stack/shellinit.go b/internal/stack/shellinit.go index f8d0240f37..97e4adff77 100644 --- a/internal/stack/shellinit.go +++ b/internal/stack/shellinit.go @@ -7,8 +7,13 @@ package stack import ( "errors" "fmt" + "os" + "path/filepath" "strings" + "github.com/elastic/go-sysinfo" + "github.com/elastic/go-sysinfo/types" + "github.com/elastic/elastic-package/internal/environment" "github.com/elastic/elastic-package/internal/profile" ) @@ -22,8 +27,26 @@ var ( CACertificateEnv = environment.WithElasticPackagePrefix("CA_CERT") ) +var shellType string +var shellDetectError error + +func init() { + shellType, shellDetectError = detectShell() +} + +// SelectShell selects the shell to use. +func SelectShell(shell string) { + shellType = shell + shellDetectError = nil +} + +// AutodetectedShell returns an error if shell could not be detected. +func AutodetectedShell() (string, error) { + return shellType, shellDetectError +} + // ShellInit method exposes environment variables that can be used for testing purposes. -func ShellInit(elasticStackProfile *profile.Profile, shellType string) (string, error) { +func ShellInit(elasticStackProfile *profile.Profile) (string, error) { config, err := StackInitConfig(elasticStackProfile) if err != nil { return "", nil @@ -61,10 +84,18 @@ set -x %s %s; set -x %s %s; set -x %s %s; ` + + // PowerShell init code. + // Output to be evaluated with `elastic-package stack shellinit | Invoke-Expression + powershellTemplate = `$Env:%s="%s"; +$Env:%s="%s"; +$Env:%s="%s"; +$Env:%s="%s"; +$Env:%s="%s";` ) // availableShellTypes list all available values for s in initTemplate -var availableShellTypes = []string{"bash", "dash", "fish", "sh", "zsh"} +var availableShellTypes = []string{"bash", "dash", "fish", "sh", "zsh", "pwsh", "powershell"} // InitTemplate returns code templates for shell initialization func initTemplate(s string) (string, error) { @@ -73,7 +104,59 @@ func initTemplate(s string) (string, error) { return posixTemplate, nil case "fish": return fishTemplate, nil + case "pwsh", "powershell": + return powershellTemplate, nil default: return "", errors.New("shell type is unknown, should be one of " + strings.Join(availableShellTypes, ", ")) } } + +// helpText returns the instrutions about how to set environment variables with shellinit +func helpText(shell string) string { + switch shell { + case "pwsh", "powershell": + return `elastic-package stack shellinit | Invoke-Expression` + default: + return `eval "$(elastic-package stack shellinit)"` + } +} + +func getShellName(exe string) string { + shell := filepath.Base(exe) + // NOTE: remove .exe extension from executable names present in Windows + shell = strings.TrimSuffix(shell, ".exe") + return shell +} + +func detectShell() (string, error) { + ppid := os.Getppid() + parentInfo, err := getParentInfo(ppid) + if err != nil { + return "", err + } + + shell := getShellName(parentInfo.Exe) + if shell == "go" { + parentParentInfo, err := getParentInfo(parentInfo.PPID) + if err != nil { + return "", fmt.Errorf("cannot retrieve parent parent info: %w", err) + } + return getShellName(parentParentInfo.Exe), nil + } + + return shell, nil +} + +func getParentInfo(ppid int) (types.ProcessInfo, error) { + parent, err := sysinfo.Process(ppid) + if err != nil { + return types.ProcessInfo{}, fmt.Errorf("cannot retrieve information for process %d: %w", ppid, err) + } + + parentInfo, err := parent.Info() + if err != nil { + return types.ProcessInfo{}, fmt.Errorf("cannot retrieve information for parent of process %d: %w", ppid, err) + } + + return parentInfo, nil +} diff --git a/internal/stack/shellinit_internal_test.go b/internal/stack/shellinit_internal_test.go index e0fd8a2a30..a36ccc2ae1 100644 --- a/internal/stack/shellinit_internal_test.go +++ b/internal/stack/shellinit_internal_test.go @@ -5,9 +5,13 @@ package stack import ( + "os" + "reflect" "strings" "testing" + "github.com/elastic/go-sysinfo" + "github.com/elastic/go-sysinfo/types" "gotest.tools/v3/assert" ) @@ -38,3 +42,63 @@ func TestCodeTemplate_wrongInput(t *testing.T) { _, err := initTemplate("invalid shell type") assert.Error(t, err, "shell type is unknown, should be one of "+strings.Join(availableShellTypes, ", ")) } + +func Test_getShellName(t *testing.T) { + type args struct { + exe string + } + tests := []struct { + name string + args args + want string + }{ + {"linux exec", args{exe: "bash"}, "bash"}, + {"windows exec", args{exe: "cmd.exe"}, "cmd"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getShellName(tt.args.exe); got != tt.want { + t.Errorf("getShellName() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_getParentInfo(t *testing.T) { + ppid := os.Getppid() + parent, err := sysinfo.Process(ppid) + if err != nil { + panic(err) + } + info, err := parent.Info() + if err != nil { + panic(err) + } + + type args struct { + ppid int + } + tests := []struct { + name string + args args + want types.ProcessInfo + wantErr bool + }{ + // TODO: Add test cases. + {"test parent", args{ppid}, info, false}, + {"bogus ppid", args{999999}, types.ProcessInfo{}, true}, + {"no parent", args{1}, types.ProcessInfo{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := getParentInfo(tt.args.ppid) + if (err != nil) != tt.wantErr { + t.Errorf("getParentInfo() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("getParentInfo() = %v, want %v", got, tt.want) + } + }) + } +}