diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 95e7c1d3..36cb9548 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -7,6 +7,11 @@ jobs: build: runs-on: ubuntu-latest steps: + - name: install test dependencies + run: | + sudo apt-get update + sudo apt-get install expect + - uses: actions/checkout@v2 - name: Set up Go @@ -32,3 +37,7 @@ jobs: - name: Test run: go test -v ./... + + - name: Run integration tests + run: make -C test test + diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..0bf7bd1c --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +test/rigtest +test/footloose.yaml +test/Library +test/.ssh diff --git a/cmd/rigtest/rigtest.go b/cmd/rigtest/rigtest.go new file mode 100644 index 00000000..d1deb8ba --- /dev/null +++ b/cmd/rigtest/rigtest.go @@ -0,0 +1,163 @@ +package main + +import ( + "flag" + "fmt" + goos "os" + "strconv" + "strings" + "time" + + "github.com/k0sproject/rig" + "github.com/k0sproject/rig/exec" + "github.com/k0sproject/rig/os" + "github.com/k0sproject/rig/os/registry" + _ "github.com/k0sproject/rig/os/support" + "github.com/kevinburke/ssh_config" +) + +type configurer interface { + WriteFile(os.Host, string, string, string) error + LineIntoFile(os.Host, string, string, string) error + ReadFile(os.Host, string) (string, error) + FileExist(os.Host, string) bool + DeleteFile(os.Host, string) error + Stat(os.Host, string, ...exec.Option) (*os.FileInfo, error) +} + +// Host is a host that utilizes rig for connections +type Host struct { + rig.Connection + + Configurer configurer +} + +// LoadOS is a function that assigns a OS support package to the host and +// typecasts it to a suitable interface +func (h *Host) LoadOS() error { + bf, err := registry.GetOSModuleBuilder(*h.OSVersion) + if err != nil { + return err + } + + h.Configurer = bf().(configurer) + + return nil +} + +func main() { + dh := flag.String("host", "127.0.0.1", "target host [+ :port], can give multiple comma separated") + usr := flag.String("user", "root", "user name") + kp := flag.String("keypath", "", "keypath") + pc := flag.Bool("askpass", false, "ask passwords") + + fn := fmt.Sprintf("test_%s.txt", time.Now().Format("20060102150405")) + + flag.Parse() + + if *dh == "" { + println("see -help") + goos.Exit(1) + } + + if configPath := goos.Getenv("SSH_CONFIG"); configPath != "" { + f, err := goos.Open(configPath) + if err != nil { + panic(err) + } + cfg, err := ssh_config.Decode(f) + if err != nil { + panic(err) + } + rig.SSHConfigGetAll = func(dst, key string) []string { + res, err := cfg.GetAll(dst, key) + if err != nil { + return nil + } + return res + } + } + + var passfunc func() (string, error) + if *pc { + passfunc = func() (string, error) { + var pass string + fmt.Print("Password: ") + fmt.Scanln(&pass) + return pass, nil + } + } + + var hosts []Host + + for _, address := range strings.Split(*dh, ",") { + port := 22 + if addr, portstr, ok := strings.Cut(address, ":"); ok { + address = addr + p, err := strconv.Atoi(portstr) + if err != nil { + panic("invalid port " + portstr) + } + port = p + } + + h := Host{ + Connection: rig.Connection{ + SSH: &rig.SSH{ + Address: address, + Port: port, + User: *usr, + KeyPath: kp, + PasswordCallback: passfunc, + }, + }, + } + hosts = append(hosts, h) + } + + for _, h := range hosts { + if err := h.Connect(); err != nil { + panic(err) + } + + if err := h.LoadOS(); err != nil { + panic(err) + } + + if err := h.Configurer.WriteFile(h, fn, "test\ntest2\ntest3", "0644"); err != nil { + panic(err) + } + + if err := h.Configurer.LineIntoFile(h, fn, "test2", "test4"); err != nil { + panic(err) + } + + if !h.Configurer.FileExist(h, fn) { + panic("file does not exist") + } + + row, err := h.Configurer.ReadFile(h, fn) + if err != nil { + panic(err) + } + if row != "test\ntest4\ntest3" { + panic("file content is not correct") + } + + stat, err := h.Configurer.Stat(h, fn) + if err != nil { + panic(err) + } + if !strings.HasSuffix(stat.FName, fn) { + panic("file stat is not correct") + } + + if err := h.Configurer.DeleteFile(h, fn); err != nil { + panic(err) + } + + if h.Configurer.FileExist(h, fn) { + panic("file still exists") + } + } +} diff --git a/connection.go b/connection.go index 55c7c10f..416d59bb 100644 --- a/connection.go +++ b/connection.go @@ -86,9 +86,8 @@ func (c *Connection) SetDefaults() { if c.client == nil { c.client = defaultClient() } + _ = defaults.Set(c.client) } - - _ = defaults.Set(c.client) } // Protocol returns the connection protocol name @@ -133,16 +132,11 @@ func (c *Connection) IsConnected() bool { // String returns a printable representation of the connection, which will look // like: `[ssh] address:port` func (c Connection) String() string { - client := c.client - if client == nil { - client = c.configuredClient() - _ = defaults.Set(c) - } - if client == nil { - client = defaultClient() + if c.client == nil { + return fmt.Sprintf("[%s] %s", c.Protocol(), c.Address()) } - return client.String() + return c.client.String() } // IsWindows returns true on windows hosts @@ -311,9 +305,7 @@ func (c *Connection) configuredClient() client { } func defaultClient() client { - c := &Localhost{Enabled: true} - _ = defaults.Set(c) - return c + return &Localhost{Enabled: true} } // GroupParams separates exec.Options from other sprintf templating args diff --git a/go.mod b/go.mod index 227f2862..467148b0 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/davidmz/go-pageant v1.0.2 github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 + github.com/kevinburke/ssh_config v1.2.0 github.com/masterzen/winrm v0.0.0-20220917170901-b07f6cb0598d github.com/mitchellh/go-homedir v1.1.0 github.com/stretchr/testify v1.8.0 diff --git a/go.sum b/go.sum index 90f59016..c8ff7fc4 100644 --- a/go.sum +++ b/go.sum @@ -46,6 +46,8 @@ github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZ github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= +github.com/kevinburke/ssh_config v1.2.0 h1:x584FjTGwHzMwvHx18PXxbBVzfnxogHaAReU4gf13a4= +github.com/kevinburke/ssh_config v1.2.0/go.mod h1:CT57kijsi8u/K/BOFA39wgDQJ9CxiF4nAY/ojJ6r6mM= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= diff --git a/log/log.go b/log/log.go index 8b9f8706..f0a9c640 100644 --- a/log/log.go +++ b/log/log.go @@ -4,6 +4,7 @@ import "fmt" // Logger interface should be implemented by the logging library you wish to use type Logger interface { + Tracef(string, ...interface{}) Debugf(string, ...interface{}) Infof(string, ...interface{}) Errorf(string, ...interface{}) @@ -12,6 +13,11 @@ type Logger interface { // Log can be assigned a proper logger, such as logrus configured to your liking. var Log Logger +// Tracef logs a trace level log message +func Tracef(t string, args ...interface{}) { + Log.Debugf(t, args...) +} + // Debugf logs a debug level log message func Debugf(t string, args ...interface{}) { Log.Debugf(t, args...) @@ -32,6 +38,11 @@ type StdLog struct { Logger } +// Debugf prints a debug level log message +func (l *StdLog) Tracef(t string, args ...interface{}) { + fmt.Println("TRACE", fmt.Sprintf(t, args...)) +} + // Debugf prints a debug level log message func (l *StdLog) Debugf(t string, args ...interface{}) { fmt.Println("DEBUG", fmt.Sprintf(t, args...)) diff --git a/ssh.go b/ssh.go index b429c696..877da115 100644 --- a/ssh.go +++ b/ssh.go @@ -16,6 +16,8 @@ import ( "strings" "sync" + "github.com/creasty/defaults" + "github.com/kevinburke/ssh_config" ssh "golang.org/x/crypto/ssh" "golang.org/x/term" @@ -29,42 +31,106 @@ import ( "github.com/mitchellh/go-homedir" ) +var authMethodCache = sync.Map{} + // SSH describes an SSH connection type SSH struct { Address string `yaml:"address" validate:"required,hostname|ip"` User string `yaml:"user" validate:"required" default:"root"` Port int `yaml:"port" default:"22" validate:"gt=0,lte=65535"` - KeyPath string `yaml:"keyPath" validate:"omitempty"` + KeyPath *string `yaml:"keyPath" validate:"omitempty"` HostKey string `yaml:"hostKey,omitempty"` Bastion *SSH `yaml:"bastion,omitempty"` PasswordCallback PasswordCallback `yaml:"-"` name string - isWindows bool - knowOs bool - keypathDefault bool + isWindows bool + knowOs bool + once sync.Once client *ssh.Client + + keyPaths []string } type PasswordCallback func() (secret string, err error) -const DefaultKeypath = "~/.ssh/id_rsa" +var defaultKeypaths = []string{"~/.ssh/id_rsa", "~/.ssh/identity", "~/.ssh/id_dsa"} +var dummyHostKeypaths []string +var globalOnce sync.Once -// SetDefaults sets various default values -func (c *SSH) SetDefaults() { - if c.KeyPath == "" { - c.KeyPath = DefaultKeypath - c.keypathDefault = true +func (c *SSH) expandKeypath(path string) (string, bool) { + expanded, err := homedir.Expand(path) + if err != nil { + return "", false } - if k, err := homedir.Expand(c.KeyPath); err == nil { - c.KeyPath = k + _, err = os.Stat(expanded) + if err != nil { + log.Debugf("%s: identity file %s not found", c, expanded) + return "", false + } + log.Tracef("%s: found identity file %s", c, expanded) + return expanded, true +} + +func (c *SSH) keypathsFromConfig() []string { + log.Tracef("%s: trying to get a keyfile path from ssh config", c) + if idf := c.getConfigAll("IdentityFile"); len(idf) > 0 { + log.Tracef("%s: detected %d identity file paths from ssh config", c, len(idf)) + return idf } + return []string{} } -// KeyPathDefaulted returns true if the keypath was not set by the user -func (c *SSH) KeyPathDefaulted() bool { - return c.keypathDefault +// SetDefaults sets various default values +func (c *SSH) SetDefaults() { + globalOnce.Do(func() { + log.Tracef("discovering global default keypaths") + dummyHostIdentityFiles := SSHConfigGetAll(hopefullyNonexistentHost, "IdentityFile") + for _, keyPath := range dummyHostIdentityFiles { + if expanded, ok := c.expandKeypath(keyPath); ok { + dummyHostKeypaths = append(dummyHostKeypaths, expanded) + } + } + }) + c.once.Do(func() { + if c.KeyPath != nil && *c.KeyPath != "" { + if expanded, ok := c.expandKeypath(*c.KeyPath); ok { + c.keyPaths = append(c.keyPaths, expanded) + } + return + } + c.KeyPath = nil + + paths := c.keypathsFromConfig() + if len(paths) == 0 { + paths = append(paths, defaultKeypaths...) + } + + for _, p := range paths { + if expanded, ok := c.expandKeypath(p); ok { + log.Debugf("%s: using identity file %s", c, expanded) + c.keyPaths = append(c.keyPaths, expanded) + } + } + + for _, keyPath := range c.keyPaths { + found := false + for _, idf := range dummyHostKeypaths { + if idf == keyPath { + found = true + break + } + } + if !found { + // found a keypath that is not in the dummy host's config, so it's not a global default + // set this as c.KeyPath so we can consider it an explicitly set keypath + c.KeyPath = &keyPath + break + } + } + + }) } // Protocol returns the protocol name, "SSH" @@ -77,6 +143,18 @@ func (c *SSH) IPAddress() string { return c.Address } +// SSHConfigGetAll by default points to ssh_config package's GetAll() function +// you can override it with your own implementation for testing purposes +var SSHConfigGetAll = ssh_config.GetAll + +func (c *SSH) getConfigAll(key string) []string { + dst := net.JoinHostPort(c.Address, strconv.Itoa(c.Port)) + if val := SSHConfigGetAll(dst, key); len(val) > 0 { + return val + } + return SSHConfigGetAll(c.Address, key) +} + // String returns the connection's printable name func (c *SSH) String() string { if c.name == "" { @@ -125,17 +203,12 @@ func trustedHostKeyCallback(trustedKey string) ssh.HostKeyCallback { } } -// signersToString returns signers key type and sha256 fingerprint -func signersToString(signers []ssh.Signer) string { - var ret strings.Builder - for _, s := range signers { - ret.WriteString("- " + keyString(s.PublicKey()) + "\n") - } - return ret.String() -} +const hopefullyNonexistentHost = "thisH0stDoe5not3xist" // Connect opens the SSH connection func (c *SSH) Connect() error { + _ = defaults.Set(c) + config := &ssh.ClientConfig{ User: c.User, } @@ -146,22 +219,43 @@ func (c *SSH) Connect() error { config.HostKeyCallback = trustedHostKeyCallback(c.HostKey) } - privateKeyAuth, err := c.getPrivateKeys() + var signers []ssh.Signer + agent, err := agentClient() if err != nil { - return err - } - if len(privateKeyAuth) > 0 { - config.Auth = append(config.Auth, privateKeyAuth...) + log.Tracef("%s: failed to get ssh agent client: %v", c, err) + } else { + signers, err = agent.Signers() + if err != nil { + log.Debugf("%s: failed to list signers from ssh agent: %v", c, err) + } } - if c.KeyPath == "" || c.KeyPathDefaulted() { - signers, err := getSshAgentSigners() + for _, keyPath := range c.keyPaths { + if am, ok := authMethodCache.Load(keyPath); ok { + switch authM := am.(type) { + case ssh.AuthMethod: + log.Tracef("%s: using cached auth method for %s", c, keyPath) + config.Auth = append(config.Auth, authM) + case error: + log.Tracef("%s: already discarded key %s: %v", c, keyPath, authM) + default: + log.Tracef("%s: unexpected type %T for cached auth method for %s", c, am, keyPath) + } + continue + } + privateKeyAuth, err := c.pkeySigner(signers, keyPath) if err != nil { - log.Debugf("failed to get signers from SSH agents: %v", err) - } else if len(signers) > 0 { - log.Debugf("Got %v signers from SSH agents:\n%s", len(signers), signersToString(signers)) - config.Auth = append(config.Auth, ssh.PublicKeys(signers...)) + log.Debugf("%s: failed to obtain a signer for identity %s: %v", c, keyPath, err) + // store the error so this key won't be loaded again + authMethodCache.Store(keyPath, err) } + authMethodCache.Store(keyPath, privateKeyAuth) + config.Auth = append(config.Auth, privateKeyAuth) + } + + if c.KeyPath == nil && len(signers) > 0 { + log.Debugf("%s: using all keys (%d) from ssh agent because a keypath was not explicitly given", c, len(signers)) + config.Auth = append(config.Auth, ssh.PublicKeys(signers...)) } dst := net.JoinHostPort(c.Address, strconv.Itoa(c.Port)) @@ -169,10 +263,11 @@ func (c *SSH) Connect() error { var client *ssh.Client if c.Bastion == nil { - client, err = ssh.Dial("tcp", dst, config) + clientDirect, err := ssh.Dial("tcp", dst, config) if err != nil { return err } + client = clientDirect } else { if err := c.Bastion.Connect(); err != nil { return err @@ -192,37 +287,66 @@ func (c *SSH) Connect() error { return nil } -func (c *SSH) getPrivateKeys() ([]ssh.AuthMethod, error) { - key, err := os.ReadFile(c.KeyPath) - if err != nil { - if c.KeyPathDefaulted() { - return nil, nil +func (c *SSH) pubkeySigner(signers []ssh.Signer, key ssh.PublicKey) (ssh.AuthMethod, error) { + if len(signers) == 0 { + return nil, fmt.Errorf("no signer available for public key") + } + + for _, s := range signers { + if bytes.Equal(key.Marshal(), s.PublicKey().Marshal()) { + log.Debugf("%s: signer for public key available in ssh agent", c) + return ssh.PublicKeys(s), nil } + } + + return nil, fmt.Errorf("the provided key is a public key and is not known by agent") +} + +func (c *SSH) pkeySigner(signers []ssh.Signer, path string) (ssh.AuthMethod, error) { + log.Tracef("%s: checking identity file %s", c, path) + key, err := os.ReadFile(path) + if err != nil { return nil, err } + pubKey, _, _, _, err := ssh.ParseAuthorizedKey(key) + if err == nil { + log.Debugf("%s: file %s is a public key", c, path) + return c.pubkeySigner(signers, pubKey) + } + signer, err := ssh.ParsePrivateKey(key) - switch err.(type) { - case nil: - return []ssh.AuthMethod{ssh.PublicKeys(signer)}, nil - case *ssh.PassphraseMissingError: - if c.PasswordCallback != nil { - auth := ssh.PublicKeysCallback(func() ([]ssh.Signer, error) { - pass, err := c.PasswordCallback() - if err != nil { - return nil, fmt.Errorf("password provider failed: %s", err) - } - signer, err := ssh.ParsePrivateKeyWithPassphrase(key, []byte(pass)) - if err != nil { - return nil, err + if err == nil { + log.Debugf("%s: using an unencrypted private key from %s", c, path) + return ssh.PublicKeys(signer), nil + } + + if _, ok := err.(*ssh.PassphraseMissingError); ok { + log.Debugf("%s: key %s is encrypted", c, path) + + if len(signers) > 0 { + if pubkeyPath, ok := c.expandKeypath(path + ".pub"); ok { + if signer, err := c.pkeySigner(signers, pubkeyPath); err == nil { + return signer, nil } - return []ssh.Signer{signer}, nil - }) - return []ssh.AuthMethod{auth}, nil + } + } + + if c.PasswordCallback != nil { + log.Tracef("%s: asking for a password to decrypt %s", c, path) + pass, err := c.PasswordCallback() + if err != nil { + return nil, fmt.Errorf("password provider failed: %w", err) + } + signer, err := ssh.ParsePrivateKeyWithPassphrase(key, []byte(pass)) + if err != nil { + return nil, fmt.Errorf("protected key decoding failed: %w", err) + } + return ssh.PublicKeys(signer), nil } } - return nil, fmt.Errorf("can't parse keyfile %s: %w", c.KeyPath, err) + return nil, fmt.Errorf("can't parse keyfile %s: %w", path, err) } // Exec executes a command on the host diff --git a/ssh_agent.go b/ssh_agent.go index 5790a608..2d260300 100644 --- a/ssh_agent.go +++ b/ssh_agent.go @@ -7,23 +7,19 @@ import ( "net" "os" - ssh "golang.org/x/crypto/ssh" + "github.com/k0sproject/rig/log" "golang.org/x/crypto/ssh/agent" ) -// getSshAgentSigners returns non empty list of signers from a SSH agent -func getSshAgentSigners() ([]ssh.Signer, error) { +func agentClient() (agent.Agent, error) { sshAgentSock := os.Getenv("SSH_AUTH_SOCK") if sshAgentSock == "" { return nil, fmt.Errorf("SSH_AUTH_SOCK is empty") } + log.Debugf("using SSH_AUTH_SOCK=%s", sshAgentSock) sshAgent, err := net.Dial("unix", sshAgentSock) if err != nil { return nil, fmt.Errorf("can't connect to SSH agent: %w", err) } - signers, err := agent.NewClient(sshAgent).Signers() - if err != nil { - return nil, fmt.Errorf("SSH agent new client: %w", err) - } - return signers, nil + return agent.NewClient(sshAgent), nil } diff --git a/ssh_agent_windows.go b/ssh_agent_windows.go index 3a5bad50..982efc5c 100644 --- a/ssh_agent_windows.go +++ b/ssh_agent_windows.go @@ -16,45 +16,13 @@ const ( openSshAgentPipe = `\\.\pipe\openssh-ssh-agent` ) -// getSshAgentSigners returns non empty list of signers from a SSH agent -func getSshAgentSigners() ([]ssh.Signer, error) { - var ( - errors []string - signers []ssh.Signer - ) - +func agentClient() (agent.Agent, error) { if pageant.Available() { - signersPageant, errGetSigners := pageant.New().Signers() - if errGetSigners != nil { - errors = append(errors, fmt.Sprintf("- Failed to get signers from Pageant: %s", errGetSigners)) - } else { - if len(signersPageant) > 0 { - signers = append(signers, signersPageant...) - } else { - errors = append(errors, "- No keys loaded in Pageant") - } - } - } else { - errors = append(errors, "- Pageant is unavailable") + return pageant.New(), nil } - sock, err := winio.DialPipe(openSshAgentPipe, nil) if err != nil { - errors = append(errors, fmt.Sprintf("- Can't connect to openssh-agent: %s", err)) - } else { - signersOpenSSHAgent, errGetSigners := agent.NewClient(sock).Signers() - if errGetSigners != nil { - errors = append(errors, fmt.Sprintf("- Failed to get signers from openssh-agent: %s", errGetSigners)) - } else { - if len(signersOpenSSHAgent) > 0 { - signers = append(signers, signersOpenSSHAgent...) - } else { - errors = append(errors, "- No keys loaded in openssh-agent") - } - } - } - if len(signers) > 0 { - return signers, nil + return nil, err } - return nil, fmt.Errorf("%s", strings.Join(errors, "\n")) + return agent.NewClient(sock), nil } diff --git a/test/Makefile b/test/Makefile new file mode 100644 index 00000000..79b3dc11 --- /dev/null +++ b/test/Makefile @@ -0,0 +1,71 @@ +KEY_TYPE ?= rsa +KEY_SIZE ?= 4096 +KEY_PASSPHRASE ?= "" +KEY_PATH ?= ".ssh/identity" +REPLICAS ?= 1 + +footloose := $(shell which footloose) +ifeq ($(footloose),) +footloose := $(shell go env GOPATH)/bin/footloose +endif + +envsubst := $(shell which envsubst) +ifeq ($(envsubst),) +$(error 'envsubst' NOT found in path, please install it and re-run) +endif + +sshkeygen := $(shell which ssh-keygen) +ifeq ($(sshkeygen),) +$(error 'ssh-keygen' NOT found in path, please install it and re-run) +endif + +.PHONY: rigtest +rigtest: + go build -o rigtest ../cmd/rigtest + +$(footloose): + go install github.com/weaveworks/footloose/...@0.6.3 + +.ssh: + mkdir -p .ssh + +.ssh/identity: .ssh + rm -f .ssh/identity + ssh-keygen -t $(KEY_TYPE) -b $(KEY_SIZE) -f .ssh/identity -N $(KEY_PASSPHRASE) + +footloose.yaml: .ssh/identity $(footloose) + $(footloose) config create \ + --config footloose.yaml \ + --image quay.io/footloose/ubuntu18.04 \ + --name rigtest \ + --key .ssh/identity \ + --replicas $(REPLICAS) \ + --override + +.PHONY: create-host +create-host: footloose.yaml + $(footloose) create -c footloose.yaml + +.PHONY: delete-host +delete-host: footloose.yaml + $(footloose) delete -c footloose.yaml + +.PHONY: clean +clean: delete-host + rm -f footloose.yaml identity rigtest + rm -rf .ssh + +.PHONY: sshport +sshport: + @$(footloose) show node0 -o json|grep hostPort|grep -oE "[0-9]+" + +.PHONY: run +run: rigtest create-host + ./rigtest \ + -host 127.0.0.1:$(shell $(MAKE) sshport) \ + -keypath $(KEY_PATH) \ + -user root + +.PHONY: test +test: rigtest + ./test.sh diff --git a/test/footloose.yaml b/test/footloose.yaml new file mode 100644 index 00000000..718a8dc3 --- /dev/null +++ b/test/footloose.yaml @@ -0,0 +1,11 @@ +cluster: + name: rigtest + privateKey: .ssh/identity +machines: +- count: 1 + spec: + backend: docker + image: quay.io/footloose/ubuntu18.04 + name: node%d + portMappings: + - containerPort: 22 diff --git a/test/test.sh b/test/test.sh new file mode 100755 index 00000000..b3f40055 --- /dev/null +++ b/test/test.sh @@ -0,0 +1,125 @@ +#!/bin/bash + +set -e + +color_echo() { + echo -e "\033[1;31m$@\033[0m" +} + +ssh_port() { + footloose show $1 -o json|grep hostPort|grep -oE "[0-9]+" +} + +rig_test_agent_with_public_key() { + color_echo "- Testing connection using agent and providing a path to public key" + make create-host + eval $(ssh-agent -s) + ssh-add .ssh/identity + rm -f .ssh/identity + set +e + HOME=$(pwd) SSH_AUTH_SOCK=$SSH_AUTH_SOCK ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root -keypath .ssh/identity.pub + local exit_code=$? + set -e + kill $SSH_AGENT_PID + export SSH_AGENT_PID= + export SSH_AUTH_SOCK= + return $exit_code +} + +rig_test_agent_with_private_key() { + color_echo "- Testing connection using agent and providing a path to protected private key" + make create-host KEY_PASSPHRASE=foo + eval $(ssh-agent -s) + expect -c ' + spawn ssh-add .ssh/identity + expect "?:" + send "foo\n" + expect eof" + ' + set +e + # path points to a private key, rig should try to look for the .pub for it + HOME=$(pwd) SSH_AUTH_SOCK=$SSH_AUTH_SOCK ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root -keypath .ssh/identity + local exit_code=$? + set -e + kill $SSH_AGENT_PID + export SSH_AGENT_PID= + export SSH_AUTH_SOCK= + return $exit_code +} + +rig_test_agent() { + color_echo "- Testing connection using any key from agent (empty keypath)" + make create-host + eval $(ssh-agent -s) + ssh-add .ssh/identity + rm -f .ssh/identity + set +e + ssh-add -l + HOME=. SSH_AUTH_SOCK=$SSH_AUTH_SOCK ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root -keypath "" + local exit_code=$? + set -e + kill $SSH_AGENT_PID + export SSH_AGENT_PID= + export SSH_AUTH_SOCK= + return $exit_code +} + +rig_test_ssh_config() { + color_echo "- Testing getting identity path from ssh config" + make create-host + mv .ssh/identity .ssh/identity2 + echo "Host 127.0.0.1:$(ssh_port node0)" > .ssh/config + echo " IdentityFile .ssh/identity2" >> .ssh/config + set +e + HOME=. SSH_CONFIG=.ssh/config ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root + local exit_code=$? + set -e + return $exit_code +} + +rig_test_key_from_path() { + color_echo "- Testing regular keypath" + make create-host + mv .ssh/identity .ssh/identity2 + set +e + ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root -keypath .ssh/identity2 + local exit_code=$? + set -e + return $exit_code +} + +rig_test_protected_key_from_path() { + color_echo "- Testing regular keypath to encrypted key, two hosts" + make create-host KEY_PASSPHRASE=foo REPLICAS=2 + set +e + ssh_port node0 > .ssh/port_A + ssh_port node1 > .ssh/port_B + expect -c ' + + set fp [open .ssh/port_A r] + set PORTA [read -nonewline $fp] + close $fp + set fp [open .ssh/port_B r] + set PORTB [read -nonewline $fp] + close $fp + + spawn ./rigtest -host 127.0.0.1:$PORTA,127.0.0.1:$PORTB -user root -keypath .ssh/identity -askpass true + expect "Password:" + send "foo\n" + expect eof" + ' $port1 $port2 + local exit_code=$? + set -e + return $exit_code +} + +for test in $(declare -F|grep rig_test_|cut -d" " -f3); do + if [ "$FOCUS" != "" ] && [ "$FOCUS" != "$test" ]; then + continue + fi + make clean + make rigtest + color_echo "\n###########################################################" + $test + echo -e "\n\n\n" +done