diff --git a/cmd/gpu/setup.go b/cmd/gpu/setup.go index aa4c9b6..9c64b13 100644 --- a/cmd/gpu/setup.go +++ b/cmd/gpu/setup.go @@ -77,7 +77,7 @@ image vmi-docker-* for a ready-made base.`, // Phase 1: install toolkit + driver. code, err := internalssh.RunScript(sshClient, gpuInstallScript(), nil, os.Stdout, os.Stderr) - sshClient.Close() + _ = sshClient.Close() if err != nil { return fmt.Errorf("install script: %w", err) } @@ -105,7 +105,7 @@ image vmi-docker-* for a ready-made base.`, if err != nil { return err } - defer sshClient.Close() + defer func() { _ = sshClient.Close() }() // Phase 4: install nvidia-utils + verify. code, err = internalssh.RunScript(sshClient, gpuVerifyScript(), nil, os.Stdout, os.Stderr) diff --git a/cmd/root.go b/cmd/root.go index 32c29b6..f0d341b 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -31,13 +31,14 @@ import ( var ( version = "dev" - flagProfile string - flagFormat string - flagNoInput bool - flagQuiet bool - flagVerbose bool - flagNoColor bool - flagYes bool + flagProfile string + flagFormat string + flagNoInput bool + flagQuiet bool + flagVerbose bool + flagNoColor bool + flagYes bool + flagSSHInsecure bool ) // rootCmd is the base command. @@ -61,6 +62,9 @@ var rootCmd = &cobra.Command{ if flagNoColor { _ = os.Setenv(config.EnvNoColor, "1") } + if flagSSHInsecure { + _ = os.Setenv(config.EnvSSHInsecure, "1") + } }, } @@ -79,6 +83,7 @@ func init() { rootCmd.PersistentFlags().BoolVar(&flagQuiet, "quiet", false, "suppress non-essential output") rootCmd.PersistentFlags().BoolVar(&flagNoColor, "no-color", false, "disable color output") rootCmd.PersistentFlags().BoolVarP(&flagYes, "yes", "y", false, "skip confirmation prompts") + rootCmd.PersistentFlags().BoolVar(&flagSSHInsecure, "insecure", false, "disable SSH host-key verification (not recommended; for lab / throwaway VPS)") rootCmd.PersistentFlags().Bool("no-headers", false, "hide table/CSV headers") rootCmd.PersistentFlags().StringArray("filter", nil, "filter rows by key=value (repeatable)") rootCmd.PersistentFlags().String("sort-by", "", "sort rows by field name") diff --git a/go.mod b/go.mod index cec05ca..8689246 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.26.1 require ( github.com/manifoldco/promptui v0.9.0 github.com/spf13/cobra v1.10.2 + github.com/spf13/pflag v1.0.9 golang.org/x/crypto v0.49.0 golang.org/x/term v0.41.0 gopkg.in/yaml.v3 v3.0.1 @@ -13,6 +14,5 @@ require ( require ( github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect - github.com/spf13/pflag v1.0.9 // indirect golang.org/x/sys v0.42.0 // indirect ) diff --git a/internal/config/env.go b/internal/config/env.go index 862c667..4edaf04 100644 --- a/internal/config/env.go +++ b/internal/config/env.go @@ -17,6 +17,7 @@ const ( EnvDebug = "CONOHA_DEBUG" EnvYes = "CONOHA_YES" EnvNoColor = "CONOHA_NO_COLOR" + EnvSSHInsecure = "CONOHA_SSH_INSECURE" ) // EnvOr returns the environment variable value if set, otherwise the fallback. @@ -46,3 +47,10 @@ func IsNoColor() bool { _, noColor := os.LookupEnv("NO_COLOR") return noColor } + +// IsSSHInsecure returns true when SSH host-key verification should be +// disabled (InsecureIgnoreHostKey). Set via --insecure flag or the env var. +// Default false — real known_hosts verification with TOFU fallback. +func IsSSHInsecure() bool { + return os.Getenv(EnvSSHInsecure) == "1" || os.Getenv(EnvSSHInsecure) == "true" +} diff --git a/internal/ssh/exec.go b/internal/ssh/exec.go index 5ad2df6..a973c79 100644 --- a/internal/ssh/exec.go +++ b/internal/ssh/exec.go @@ -9,9 +9,13 @@ import ( "time" "golang.org/x/crypto/ssh" + + configpkg "github.com/crowdy/conoha-cli/internal/config" ) -// ConnectConfig holds SSH connection parameters. +// ConnectConfig holds SSH connection parameters. Host-key verification is +// controlled globally via the --insecure flag / CONOHA_SSH_INSECURE env var +// (see configpkg.IsSSHInsecure); there is no per-call opt-out by design. type ConnectConfig struct { Host string // IP or hostname Port string // default "22" @@ -38,15 +42,20 @@ func Connect(cfg ConnectConfig) (*ssh.Client, error) { return nil, fmt.Errorf("parse key %s: %w", cfg.KeyPath, err) } - config := &ssh.ClientConfig{ + hostKeyCB, err := HostKeyCallback(configpkg.IsSSHInsecure(), configpkg.IsNoInput()) + if err != nil { + return nil, err + } + + clientCfg := &ssh.ClientConfig{ User: cfg.User, Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)}, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), // personal VPS use + HostKeyCallback: hostKeyCB, Timeout: 30 * time.Second, } addr := fmt.Sprintf("%s:%s", cfg.Host, cfg.Port) - return ssh.Dial("tcp", addr, config) + return ssh.Dial("tcp", addr, clientCfg) } // RunScript uploads and executes a script on the remote server. diff --git a/internal/ssh/knownhosts.go b/internal/ssh/knownhosts.go new file mode 100644 index 0000000..3d3aabb --- /dev/null +++ b/internal/ssh/knownhosts.go @@ -0,0 +1,149 @@ +package ssh + +import ( + "bufio" + "fmt" + "net" + "os" + "path/filepath" + "strings" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/knownhosts" + "golang.org/x/term" +) + +// HostKeyCallback returns an ssh.HostKeyCallback that verifies the remote +// host key against ~/.ssh/known_hosts. On first connect to an unknown host +// it prompts the operator to accept and pin the key (TOFU). When noInput +// is true (CONOHA_NO_INPUT) or stdin is not a TTY, the connection fails +// rather than silently trusting. +// +// insecure=true returns the legacy InsecureIgnoreHostKey callback for lab +// and throwaway-VPS use; documented as the explicit opt-out for operators +// who knowingly want the old v0.1.x behavior back. +func HostKeyCallback(insecure, noInput bool) (ssh.HostKeyCallback, error) { + if insecure { + return ssh.InsecureIgnoreHostKey(), nil //nolint:gosec // user-requested via --insecure + } + + path, err := knownHostsPath() + if err != nil { + return nil, err + } + + // knownhosts.New rejects a missing file. Create an empty one so the + // TOFU prompt path can append to it on first use. + if _, err := os.Stat(path); os.IsNotExist(err) { + if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil { + return nil, fmt.Errorf("creating %s dir: %w", filepath.Dir(path), err) + } + f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0o600) + if err != nil { + return nil, fmt.Errorf("creating %s: %w", path, err) + } + _ = f.Close() + } + + strict, err := knownhosts.New(path) + if err != nil { + return nil, fmt.Errorf("parsing %s: %w", path, err) + } + + return func(hostname string, remote net.Addr, key ssh.PublicKey) error { + if err := strict(hostname, remote, key); err == nil { + return nil + } else if kkErr, ok := err.(*knownhosts.KeyError); ok { + if len(kkErr.Want) > 0 { + // Key mismatch — never auto-accept; this is a MITM signal. + return &HostKeyMismatchError{Host: hostname, Path: path, Err: kkErr} + } + // Unknown host: TOFU prompt — only when stdin is genuinely + // interactive. A non-TTY stdin (CI, build script piping a + // heredoc, wrapper without --no-input) would otherwise let + // `yes\n` from an untrusted source silently trust the host. + if noInput || !term.IsTerminal(int(os.Stdin.Fd())) { + return fmt.Errorf("host %s not in %s and stdin is not interactive (no-input mode or non-TTY) — refusing to trust unknown host. Add manually with ssh-keyscan or use --insecure", hostname, path) + } + return promptAndPin(path, hostname, remote, key) + } else { + return err + } + }, nil +} + +// HostKeyMismatchError is returned when the server presents a host key that +// disagrees with the one pinned in known_hosts. Deliberately distinct from a +// plain error so callers can print MITM-specific guidance. +type HostKeyMismatchError struct { + Host string + Path string + Err error +} + +func (e *HostKeyMismatchError) Error() string { + return fmt.Sprintf( + "host key for %s has changed! This is either the server was rebuilt or a man-in-the-middle attack.\n"+ + " Pinned in: %s\n"+ + " Underlying: %v\n"+ + "If you just rebuilt the VPS, run: ssh-keygen -R %s (removes the old pin, next connect re-pins).", + e.Host, e.Path, e.Err, e.Host) +} + +func (e *HostKeyMismatchError) Unwrap() error { return e.Err } + +// promptAndPin asks the user to accept the unknown key, then appends it to +// known_hosts in the canonical OpenSSH format. +func promptAndPin(path, hostname string, remote net.Addr, key ssh.PublicKey) error { + fp := ssh.FingerprintSHA256(key) + fmt.Fprintf(os.Stderr, "\nThe authenticity of host %q can't be established.\n", hostname) + fmt.Fprintf(os.Stderr, "%s key fingerprint is %s.\n", key.Type(), fp) + fmt.Fprint(os.Stderr, "Are you sure you want to continue connecting (yes/no)? ") + + reader := bufio.NewReader(os.Stdin) + line, err := reader.ReadString('\n') + if err != nil { + return fmt.Errorf("reading prompt answer: %w", err) + } + answer := strings.TrimSpace(strings.ToLower(line)) + if answer != "yes" && answer != "y" { + return fmt.Errorf("host %s key rejected by user", hostname) + } + + // Canonical line: " " + // knownhosts.Normalize returns host[:port] → host when port is 22. + addr := knownhosts.Normalize(hostname) + // Also include the numeric address so that later SSH sessions by IP + // (common in this CLI — we connect to IPs, not names) also match. + addrs := []string{addr} + if _, ok := remote.(*net.TCPAddr); ok { + if na := knownhosts.Normalize(remote.String()); na != addr { + addrs = append(addrs, na) + } + } + line = knownhosts.Line(addrs, key) + + f, err := os.OpenFile(path, os.O_APPEND|os.O_WRONLY, 0o600) + if err != nil { + return fmt.Errorf("opening %s for append: %w", path, err) + } + defer f.Close() + if _, err := f.WriteString(line + "\n"); err != nil { + return fmt.Errorf("writing %s: %w", path, err) + } + fmt.Fprintf(os.Stderr, "Warning: Permanently added %q (%s) to the list of known hosts.\n", hostname, key.Type()) + return nil +} + +// knownHostsPath returns the path to the user's known_hosts file. +// Honors SSH_KNOWN_HOSTS override for tests and bespoke setups. +func knownHostsPath() (string, error) { + if p := os.Getenv("SSH_KNOWN_HOSTS"); p != "" { + return p, nil + } + home, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("resolving $HOME: %w", err) + } + return filepath.Join(home, ".ssh", "known_hosts"), nil +} diff --git a/internal/ssh/knownhosts_test.go b/internal/ssh/knownhosts_test.go new file mode 100644 index 0000000..c8fd2a4 --- /dev/null +++ b/internal/ssh/knownhosts_test.go @@ -0,0 +1,146 @@ +package ssh + +import ( + "crypto/ed25519" + "crypto/rand" + "errors" + "net" + "os" + "path/filepath" + "strings" + "testing" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/knownhosts" + "golang.org/x/term" +) + +func TestHostKeyCallback_Insecure(t *testing.T) { + cb, err := HostKeyCallback(true, false) + if err != nil { + t.Fatalf("Insecure path should never error, got %v", err) + } + // The insecure callback accepts any key without reading known_hosts. + key := genKey(t) + if err := cb("example.com:22", fakeTCPAddr(t, "1.2.3.4:22"), key); err != nil { + t.Errorf("Insecure callback rejected key: %v", err) + } +} + +func TestHostKeyCallback_MismatchIsDistinctError(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "known_hosts") + t.Setenv("SSH_KNOWN_HOSTS", path) + + // Pre-populate known_hosts with a key for "example.com:22", then hand + // the callback a different key for the same host. + pinned := genKey(t) + line := knownhosts.Line([]string{knownhosts.Normalize("example.com:22")}, pinned) + if err := os.WriteFile(path, []byte(line+"\n"), 0o600); err != nil { + t.Fatalf("seeding known_hosts: %v", err) + } + + cb, err := HostKeyCallback(false, true /* noInput — avoids TOFU prompt */) + if err != nil { + t.Fatalf("HostKeyCallback: %v", err) + } + + other := genKey(t) + err = cb("example.com:22", fakeTCPAddr(t, "1.2.3.4:22"), other) + if err == nil { + t.Fatal("expected a mismatch error, got nil") + } + var mismatch *HostKeyMismatchError + if !errors.As(err, &mismatch) { + t.Fatalf("expected HostKeyMismatchError, got %T: %v", err, err) + } + if mismatch.Host != "example.com:22" { + t.Errorf("mismatch.Host = %q, want example.com:22", mismatch.Host) + } + if !strings.Contains(mismatch.Error(), "ssh-keygen -R") { + t.Errorf("mismatch error should suggest ssh-keygen -R, got: %s", mismatch.Error()) + } +} + +func TestHostKeyCallback_UnknownHost_NoInput(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "known_hosts") + t.Setenv("SSH_KNOWN_HOSTS", path) + + cb, err := HostKeyCallback(false, true /* noInput */) + if err != nil { + t.Fatalf("HostKeyCallback: %v", err) + } + + key := genKey(t) + err = cb("new-host:22", fakeTCPAddr(t, "1.2.3.4:22"), key) + if err == nil { + t.Fatal("expected refusal in no-input mode, got nil") + } + if !strings.Contains(err.Error(), "not in") || !strings.Contains(err.Error(), "--insecure") { + t.Errorf("expected helpful no-input error message, got: %v", err) + } +} + +func TestHostKeyCallback_UnknownHost_NonTTYFailsClosed(t *testing.T) { + // Under `go test`, stdin is typically a pipe; if a developer runs the + // suite from a real terminal we skip rather than block on a prompt. + if term.IsTerminal(int(os.Stdin.Fd())) { + t.Skip("stdin is a TTY; non-TTY guard cannot be exercised here") + } + + dir := t.TempDir() + path := filepath.Join(dir, "known_hosts") + t.Setenv("SSH_KNOWN_HOSTS", path) + + cb, err := HostKeyCallback(false, false /* noInput=false on purpose */) + if err != nil { + t.Fatalf("HostKeyCallback: %v", err) + } + + key := genKey(t) + err = cb("new-host:22", fakeTCPAddr(t, "1.2.3.4:22"), key) + if err == nil { + t.Fatal("expected refusal when stdin is non-TTY, got nil") + } + if !strings.Contains(err.Error(), "non-TTY") { + t.Errorf("expected error to mention non-TTY, got: %v", err) + } +} + +func TestHostKeyCallback_CreatesMissingKnownHosts(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "ssh-sub", "known_hosts") + t.Setenv("SSH_KNOWN_HOSTS", path) + + if _, err := HostKeyCallback(false, true); err != nil { + t.Fatalf("HostKeyCallback should auto-create the file: %v", err) + } + if _, err := os.Stat(path); err != nil { + t.Fatalf("expected %s to be created, got %v", path, err) + } +} + +// --- helpers --- + +func genKey(t *testing.T) ssh.PublicKey { + t.Helper() + pub, _, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("ed25519.GenerateKey: %v", err) + } + k, err := ssh.NewPublicKey(pub) + if err != nil { + t.Fatalf("ssh.NewPublicKey: %v", err) + } + return k +} + +func fakeTCPAddr(t *testing.T, s string) net.Addr { + t.Helper() + a, err := net.ResolveTCPAddr("tcp", s) + if err != nil { + t.Fatalf("ResolveTCPAddr: %v", err) + } + return a +}