Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 4 additions & 47 deletions cmd/stack.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,12 @@ package cmd

import (
"fmt"
"os"
"path/filepath"
"strings"

"github.com/jedib0t/go-pretty/table"
"github.com/pkg/errors"
"github.com/spf13/cobra"

"github.com/elastic/go-sysinfo"
"github.com/elastic/go-sysinfo/types"

"github.com/elastic/elastic-package/internal/cobraext"
"github.com/elastic/elastic-package/internal/common"
"github.com/elastic/elastic-package/internal/install"
Expand Down Expand Up @@ -209,19 +204,21 @@ func setupStackCommand() *cobraext.Command {
}

if shellName == cobraext.ShellInitShellDetect {
shellName, err = detectShell()
shellName, err = stack.AutodetectedShell()
if err != nil {
return fmt.Errorf("cannot detect parent shell from current process: %w", err)
}
fmt.Fprintf(cmd.OutOrStderr(), "Detected shell: %s\n", shellName)
} else {
stack.SelectShell(shellName)
}

profile, err := profile.LoadProfile(profileName)
if err != nil {
return errors.Wrap(err, "error loading profile")
}

shellCode, err := stack.ShellInit(profile, shellName)
shellCode, err := stack.ShellInit(profile)
if err != nil {
return errors.Wrap(err, "shellinit failed")
}
Expand Down Expand Up @@ -352,43 +349,3 @@ func printStatus(cmd *cobra.Command, servicesStatus []stack.ServiceStatus) {
t.SetStyle(table.StyleRounded)
cmd.Println(t.Render())
}

func getParentInfo(ppid int) (types.ProcessInfo, error) {
parent, err := sysinfo.Process(ppid)
if err != nil {
return types.ProcessInfo{}, fmt.Errorf("cannot retrieve information for process %d: %w", ppid, err)
}

parentInfo, err := parent.Info()
if err != nil {
return types.ProcessInfo{}, fmt.Errorf("cannot retrieve information for parent of process %d: %w", ppid, err)
}

return parentInfo, nil
}

func getShellName(exe string) string {
shell := filepath.Base(exe)
// NOTE: remove .exe extension from executable names present in Windows
shell = strings.TrimSuffix(shell, ".exe")
return shell
}

func detectShell() (string, error) {
ppid := os.Getppid()
parentInfo, err := getParentInfo(ppid)
if err != nil {
return "", err
}

shell := getShellName(parentInfo.Exe)
if shell == "go" {
parentParentInfo, err := getParentInfo(parentInfo.PPID)
if err != nil {
return "", fmt.Errorf("cannot retrieve parent parent info: %w", err)
}
return getShellName(parentParentInfo.Exe), nil
}

return shell, nil
}
64 changes: 0 additions & 64 deletions cmd/stack_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,8 @@ package cmd

import (
"fmt"
"os"
"reflect"
"testing"

"github.com/elastic/go-sysinfo"
"github.com/elastic/go-sysinfo/types"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -51,63 +47,3 @@ func TestValidateServicesFlag(t *testing.T) {
}

}

func Test_getShellName(t *testing.T) {
type args struct {
exe string
}
tests := []struct {
name string
args args
want string
}{
{"linux exec", args{exe: "bash"}, "bash"},
{"windows exec", args{exe: "cmd.exe"}, "cmd"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := getShellName(tt.args.exe); got != tt.want {
t.Errorf("getShellName() = %v, want %v", got, tt.want)
}
})
}
}

func Test_getParentInfo(t *testing.T) {
ppid := os.Getppid()
parent, err := sysinfo.Process(ppid)
if err != nil {
panic(err)
}
info, err := parent.Info()
if err != nil {
panic(err)
}

type args struct {
ppid int
}
tests := []struct {
name string
args args
want types.ProcessInfo
wantErr bool
}{
// TODO: Add test cases.
{"test parent", args{ppid}, info, false},
{"bogus ppid", args{999999}, types.ProcessInfo{}, true},
{"no parent", args{1}, types.ProcessInfo{}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := getParentInfo(tt.args.ppid)
if (err != nil) != tt.wantErr {
t.Errorf("getParentInfo() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("getParentInfo() = %v, want %v", got, tt.want)
}
})
}
}
2 changes: 1 addition & 1 deletion internal/stack/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ import "fmt"
// UndefinedEnvError formats an error reported for undefined variable.
func UndefinedEnvError(envName string) error {
return fmt.Errorf("undefined environment variable: %s. If you have started the Elastic stack using the elastic-package tool, "+
`please load stack environment variables using 'eval "$(elastic-package stack shellinit)"' or set their values manually`, envName)
`please load stack environment variables using '%s' or set their values manually`, envName, helpText(shellType))
}
87 changes: 85 additions & 2 deletions internal/stack/shellinit.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,13 @@ package stack
import (
"errors"
"fmt"
"os"
"path/filepath"
"strings"

"github.com/elastic/go-sysinfo"
"github.com/elastic/go-sysinfo/types"

"github.com/elastic/elastic-package/internal/environment"
"github.com/elastic/elastic-package/internal/profile"
)
Expand All @@ -22,8 +27,26 @@ var (
CACertificateEnv = environment.WithElasticPackagePrefix("CA_CERT")
)

var shellType string
var shellDetectError error

func init() {
shellType, shellDetectError = detectShell()
}

// SelectShell selects the shell to use.
func SelectShell(shell string) {
shellType = shell
shellDetectError = nil
}

// AutodetectedShell returns an error if shell could not be detected.
func AutodetectedShell() (string, error) {
return shellType, shellDetectError
}

// ShellInit method exposes environment variables that can be used for testing purposes.
func ShellInit(elasticStackProfile *profile.Profile, shellType string) (string, error) {
func ShellInit(elasticStackProfile *profile.Profile) (string, error) {
config, err := StackInitConfig(elasticStackProfile)
if err != nil {
return "", nil
Expand Down Expand Up @@ -61,10 +84,18 @@ set -x %s %s;
set -x %s %s;
set -x %s %s;
`

// PowerShell init code.
// Output to be evaluated with `elastic-package stack shellinit | Invoke-Expression
powershellTemplate = `$Env:%s="%s";
$Env:%s="%s";
$Env:%s="%s";
$Env:%s="%s";
$Env:%s="%s";`
)

// availableShellTypes list all available values for s in initTemplate
var availableShellTypes = []string{"bash", "dash", "fish", "sh", "zsh"}
var availableShellTypes = []string{"bash", "dash", "fish", "sh", "zsh", "pwsh", "powershell"}

// InitTemplate returns code templates for shell initialization
func initTemplate(s string) (string, error) {
Expand All @@ -73,7 +104,59 @@ func initTemplate(s string) (string, error) {
return posixTemplate, nil
case "fish":
return fishTemplate, nil
case "pwsh", "powershell":
return powershellTemplate, nil
default:
return "", errors.New("shell type is unknown, should be one of " + strings.Join(availableShellTypes, ", "))
}
}

// helpText returns the instrutions about how to set environment variables with shellinit
func helpText(shell string) string {
switch shell {
case "pwsh", "powershell":
return `elastic-package stack shellinit | Invoke-Expression`
default:
return `eval "$(elastic-package stack shellinit)"`
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@endorama is this syntax also used on fish?

}
}

func getShellName(exe string) string {
shell := filepath.Base(exe)
// NOTE: remove .exe extension from executable names present in Windows
shell = strings.TrimSuffix(shell, ".exe")
return shell
}

func detectShell() (string, error) {
ppid := os.Getppid()
parentInfo, err := getParentInfo(ppid)
if err != nil {
return "", err
}

shell := getShellName(parentInfo.Exe)
if shell == "go" {
parentParentInfo, err := getParentInfo(parentInfo.PPID)
if err != nil {
return "", fmt.Errorf("cannot retrieve parent parent info: %w", err)
}
return getShellName(parentParentInfo.Exe), nil
}

return shell, nil
}

func getParentInfo(ppid int) (types.ProcessInfo, error) {
parent, err := sysinfo.Process(ppid)
if err != nil {
return types.ProcessInfo{}, fmt.Errorf("cannot retrieve information for process %d: %w", ppid, err)
}

parentInfo, err := parent.Info()
if err != nil {
return types.ProcessInfo{}, fmt.Errorf("cannot retrieve information for parent of process %d: %w", ppid, err)
}

return parentInfo, nil
}
64 changes: 64 additions & 0 deletions internal/stack/shellinit_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@
package stack

import (
"os"
"reflect"
"strings"
"testing"

"github.com/elastic/go-sysinfo"
"github.com/elastic/go-sysinfo/types"
"gotest.tools/v3/assert"
)

Expand Down Expand Up @@ -38,3 +42,63 @@ func TestCodeTemplate_wrongInput(t *testing.T) {
_, err := initTemplate("invalid shell type")
assert.Error(t, err, "shell type is unknown, should be one of "+strings.Join(availableShellTypes, ", "))
}

func Test_getShellName(t *testing.T) {
type args struct {
exe string
}
tests := []struct {
name string
args args
want string
}{
{"linux exec", args{exe: "bash"}, "bash"},
{"windows exec", args{exe: "cmd.exe"}, "cmd"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := getShellName(tt.args.exe); got != tt.want {
t.Errorf("getShellName() = %v, want %v", got, tt.want)
}
})
}
}

func Test_getParentInfo(t *testing.T) {
ppid := os.Getppid()
parent, err := sysinfo.Process(ppid)
if err != nil {
panic(err)
}
info, err := parent.Info()
if err != nil {
panic(err)
}

type args struct {
ppid int
}
tests := []struct {
name string
args args
want types.ProcessInfo
wantErr bool
}{
// TODO: Add test cases.
{"test parent", args{ppid}, info, false},
{"bogus ppid", args{999999}, types.ProcessInfo{}, true},
{"no parent", args{1}, types.ProcessInfo{}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := getParentInfo(tt.args.ppid)
if (err != nil) != tt.wantErr {
t.Errorf("getParentInfo() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("getParentInfo() = %v, want %v", got, tt.want)
}
})
}
}