diff --git a/go.mod b/go.mod index a7d8b6dadc9..f79599bb4e0 100644 --- a/go.mod +++ b/go.mod @@ -117,7 +117,7 @@ require ( github.com/x448/float16 v0.8.4 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/yuin/gopher-lua v1.1.1 // indirect - golang.org/x/crypto v0.45.0 // indirect + golang.org/x/crypto v0.45.0 golang.org/x/mod v0.29.0 // indirect golang.org/x/oauth2 v0.32.0 // indirect golang.org/x/term v0.37.0 // indirect diff --git a/pkg/driver/vz/vm_darwin.go b/pkg/driver/vz/vm_darwin.go index 15cff0444e3..11fbce55595 100644 --- a/pkg/driver/vz/vm_darwin.go +++ b/pkg/driver/vz/vm_darwin.go @@ -113,18 +113,18 @@ func startVM(ctx context.Context, inst *limatype.Instance, sshLocalPort int) (vm useSSHOverVsock = b } } + hostAddress := net.JoinHostPort(inst.SSHAddress, strconv.Itoa(usernetSSHLocalPort)) if !useSSHOverVsock { logrus.Info("LIMA_SSH_OVER_VSOCK is false, skipping detection of SSH server on vsock port") - } else if err := usernetClient.WaitOpeningSSHPort(ctx, inst); err == nil { - hostAddress := net.JoinHostPort(inst.SSHAddress, strconv.Itoa(usernetSSHLocalPort)) - if err := wrapper.startVsockForwarder(ctx, 22, hostAddress); err == nil { - logrus.Infof("Detected SSH server is listening on the vsock port; changed %s to proxy for the vsock port", hostAddress) - usernetSSHLocalPort = 0 // disable gvisor ssh port forwarding - } else { - logrus.WithError(err).Warn("Failed to detect SSH server on vsock port, falling back to usernet forwarder") - } + } else if err := usernetClient.WaitOpeningSSHPort(ctx, inst); err != nil { + logrus.WithError(err).Info("Failed to wait for the guest SSH server to become available, falling back to usernet forwarder") + } else if err := wrapper.checkSSHOverVsockAvailable(ctx, inst); err != nil { + logrus.WithError(err).Info("Failed to detect SSH server on vsock port, falling back to usernet forwarder") + } else if err := wrapper.startVsockForwarder(ctx, 22, hostAddress); err != nil { + logrus.WithError(err).Info("Failed to start SSH server forwarder on vsock port, falling back to usernet forwarder") } else { - logrus.WithError(err).Warn("Failed to wait for the guest SSH server to become available, falling back to usernet forwarder") + logrus.Infof("Detected SSH server is listening on the vsock port; changed %s to proxy for the vsock port", hostAddress) + usernetSSHLocalPort = 0 // disable gvisor ssh port forwarding } err := usernetClient.ConfigureDriver(ctx, inst, usernetSSHLocalPort) if err != nil { diff --git a/pkg/driver/vz/vsock_forwarder.go b/pkg/driver/vz/vsock_forwarder.go index 044c3d5105a..6109d2e1373 100644 --- a/pkg/driver/vz/vsock_forwarder.go +++ b/pkg/driver/vz/vsock_forwarder.go @@ -12,17 +12,14 @@ import ( "github.com/containers/gvisor-tap-vsock/pkg/tcpproxy" "github.com/sirupsen/logrus" + + "github.com/lima-vm/lima/v2/pkg/limatype" + "github.com/lima-vm/lima/v2/pkg/sshutil" ) func (m *virtualMachineWrapper) startVsockForwarder(ctx context.Context, vsockPort uint32, hostAddress string) error { - // Test if the vsock port is open - conn, err := m.dialVsock(ctx, vsockPort) - if err != nil { - return err - } - conn.Close() // Start listening on localhost:hostPort and forward to vsock:vsockPort - _, _, err = net.SplitHostPort(hostAddress) + _, _, err := net.SplitHostPort(hostAddress) if err != nil { return err } @@ -73,3 +70,9 @@ func (m *virtualMachineWrapper) dialVsock(_ context.Context, port uint32) (conn } return nil, err } + +func (m *virtualMachineWrapper) checkSSHOverVsockAvailable(ctx context.Context, inst *limatype.Instance) error { + return sshutil.WaitSSHReady(ctx, func(ctx context.Context) (net.Conn, error) { + return m.dialVsock(ctx, uint32(22)) + }, "vsock:22", *inst.Config.User.Name, 1) +} diff --git a/pkg/networks/usernet/client.go b/pkg/networks/usernet/client.go index 6a2437c3bd7..9f8799677e7 100644 --- a/pkg/networks/usernet/client.go +++ b/pkg/networks/usernet/client.go @@ -140,8 +140,9 @@ func (c *Client) WaitOpeningSSHPort(ctx context.Context, inst *limatype.Instance if err != nil { return err } + user := *inst.Config.User.Name // -1 avoids both sides timing out simultaneously. - u := fmt.Sprintf("%s/extension/wait_port?ip=%s&port=22&timeout=%d", c.base, ipAddr, timeoutSeconds-1) + u := fmt.Sprintf("%s/extension/wait-ssh-server?ip=%s&port=22&timeout=%d&user=%s", c.base, ipAddr, timeoutSeconds-1, user) res, err := httpclientutil.Get(ctx, c.client, u) if err != nil { return err diff --git a/pkg/networks/usernet/gvproxy.go b/pkg/networks/usernet/gvproxy.go index 997fecd12ce..25edcafa35b 100644 --- a/pkg/networks/usernet/gvproxy.go +++ b/pkg/networks/usernet/gvproxy.go @@ -22,6 +22,8 @@ import ( "github.com/containers/gvisor-tap-vsock/pkg/virtualnetwork" "github.com/sirupsen/logrus" "golang.org/x/sync/errgroup" + + "github.com/lima-vm/lima/v2/pkg/sshutil" ) type GVisorNetstackOpts struct { @@ -243,7 +245,7 @@ func httpServe(ctx context.Context, g *errgroup.Group, ln net.Listener, mux http func muxWithExtension(n *virtualnetwork.VirtualNetwork) *http.ServeMux { m := n.Mux() - m.HandleFunc("/extension/wait_port", func(w http.ResponseWriter, r *http.Request) { + m.HandleFunc("/extension/wait-ssh-server", func(w http.ResponseWriter, r *http.Request) { ip := r.URL.Query().Get("ip") if net.ParseIP(ip) == nil { msg := fmt.Sprintf("invalid ip address: %s", ip) @@ -255,8 +257,14 @@ func muxWithExtension(n *virtualnetwork.VirtualNetwork) *http.ServeMux { http.Error(w, err.Error(), http.StatusBadRequest) return } - port := uint16(port16) - addr := fmt.Sprintf("%s:%d", ip, port) + addr := net.JoinHostPort(ip, fmt.Sprintf("%d", uint16(port16))) + + user := r.URL.Query().Get("user") + if user == "" { + msg := "user query parameter is required" + http.Error(w, msg, http.StatusBadRequest) + return + } timeoutSeconds := 10 if timeoutString := r.URL.Query().Get("timeout"); timeoutString != "" { @@ -267,27 +275,14 @@ func muxWithExtension(n *virtualnetwork.VirtualNetwork) *http.ServeMux { } timeoutSeconds = int(timeout16) } - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSeconds)*time.Second) - defer cancel() + dialContext := func(ctx context.Context) (net.Conn, error) { + return n.DialContextTCP(ctx, addr) + } // Wait until the port is available. - for { - conn, err := n.DialContextTCP(ctx, addr) - if err == nil { - conn.Close() - logrus.Debugf("Port is available on %s", addr) - w.WriteHeader(http.StatusOK) - break - } - select { - case <-ctx.Done(): - msg := fmt.Sprintf("timed out waiting for port to become available on %s", addr) - logrus.Warn(msg) - http.Error(w, msg, http.StatusRequestTimeout) - return - default: - } - logrus.Debugf("Waiting for port to become available on %s", addr) - time.Sleep(1 * time.Second) + if err = sshutil.WaitSSHReady(r.Context(), dialContext, addr, user, timeoutSeconds); err != nil { + http.Error(w, err.Error(), http.StatusRequestTimeout) + } else { + w.WriteHeader(http.StatusOK) } }) return m diff --git a/pkg/osutil/osutil_unix.go b/pkg/osutil/osutil_unix.go index cf00ff69237..285014e1f2d 100644 --- a/pkg/osutil/osutil_unix.go +++ b/pkg/osutil/osutil_unix.go @@ -8,6 +8,7 @@ package osutil import ( "bytes" "context" + "errors" "fmt" "os" "os/exec" @@ -36,3 +37,7 @@ func Sysctl(ctx context.Context, name string) (string, error) { } return strings.TrimSuffix(string(stdout), "\n"), nil } + +func IsConnectionResetError(err error) bool { + return errors.Is(err, syscall.ECONNRESET) +} diff --git a/pkg/osutil/osutil_windows.go b/pkg/osutil/osutil_windows.go index a5ed533d988..ac27bce1e92 100644 --- a/pkg/osutil/osutil_windows.go +++ b/pkg/osutil/osutil_windows.go @@ -57,3 +57,7 @@ func SignalName(sig os.Signal) string { func Sysctl(_ context.Context, _ string) (string, error) { return "", errors.New("sysctl: unimplemented on Windows") } + +func IsConnectionResetError(err error) bool { + return errors.Is(err, syscall.WSAECONNRESET) +} diff --git a/pkg/sshutil/sshutil.go b/pkg/sshutil/sshutil.go index b60491b10e7..4331fc2e001 100644 --- a/pkg/sshutil/sshutil.go +++ b/pkg/sshutil/sshutil.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "io/fs" + "net" "os" "os/exec" "path/filepath" @@ -24,6 +25,7 @@ import ( "github.com/coreos/go-semver/semver" "github.com/mattn/go-shellwords" "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" "golang.org/x/sys/cpu" "github.com/lima-vm/lima/v2/pkg/ioutilx" @@ -509,3 +511,127 @@ func detectAESAcceleration() bool { } return cpu.ARM.HasAES || cpu.ARM64.HasAES || cpu.PPC64.IsPOWER8 || cpu.S390X.HasAES || cpu.X86.HasAES } + +// WaitSSHReady waits until the SSH server is ready to accept connections. +// The dialContext function is used to create a connection to the SSH server. +// The addr, user, parameter is used for ssh.ClientConn creation. +// The timeoutSeconds parameter specifies the maximum number of seconds to wait. +func WaitSSHReady(ctx context.Context, dialContext func(context.Context) (net.Conn, error), addr, user string, timeoutSeconds int) error { + ctx, cancel := context.WithTimeout(ctx, time.Duration(timeoutSeconds)*time.Second) + defer cancel() + + // Prepare signer + signer, err := userPrivateKeySigner() + if err != nil { + return err + } + // Prepare ssh client config + sshConfig := &ssh.ClientConfig{ + User: user, + Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)}, + HostKeyCallback: hostKeyCollector().checker(), + Timeout: 10 * time.Second, + } + // Wait until the SSH server is available. + for { + conn, err := dialContext(ctx) + if err == nil { + sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, sshConfig) + if err == nil { + sshClient := ssh.NewClient(sshConn, chans, reqs) + return sshClient.Close() + } + conn.Close() + if !isRetryableError(err) { + return fmt.Errorf("failed to create ssh.Conn to %q: %w", addr, err) + } + } + logrus.Debugf("Waiting for SSH port to accept connections on %s", addr) + select { + case <-ctx.Done(): + return fmt.Errorf("failed to waiting for SSH port to become available on %s: %w", addr, ctx.Err()) + case <-time.After(1 * time.Second): + continue + } + } +} + +// errHostKeyMismatch is returned when the SSH host key does not match known hosts. +var errHostKeyMismatch = errors.New("ssh: host key mismatch") + +func isRetryableError(err error) bool { + // Port forwarder accepted the connection, but the destination is not ready yet. + return osutil.IsConnectionResetError(err) || + // SSH server not ready yet (e.g. host key not generated on initial boot). + strings.HasSuffix(err.Error(), "no supported methods remain") || + // Host key is not yet in known_hosts, but will be collected, so we can retry. + errors.Is(err, errHostKeyMismatch) +} + +// userPrivateKeySigner returns the user's private key signer. +// The public key is always installed in the VM. +func userPrivateKeySigner() (ssh.Signer, error) { + configDir, err := dirnames.LimaConfigDir() + if err != nil { + return nil, err + } + privateKeyPath := filepath.Join(configDir, filenames.UserPrivateKey) + key, err := os.ReadFile(privateKeyPath) + if err != nil { + return nil, fmt.Errorf("failed to read private key %q: %w", privateKeyPath, err) + } + signer, err := ssh.ParsePrivateKey(key) + if err != nil { + return nil, fmt.Errorf("failed to parse private key %q: %w", privateKeyPath, err) + } + return signer, nil +} + +// hostKeyCollector is a singleton host key collector. +var hostKeyCollector = sync.OnceValue(func() *_hostKeyCollector { + return &_hostKeyCollector{ + hostKeys: make(map[string]ssh.PublicKey), + } +}) + +type _hostKeyCollector struct { + hostKeys map[string]ssh.PublicKey + mu sync.Mutex +} + +// checker returns a HostKeyCallback that either checks and collects the host key, +// or only checks the host key, depending on whether any host keys have been collected. +// It is expected to pass host key checks by retrying after the first collection. +// On second invocation, it will only check the host key. +func (h *_hostKeyCollector) checker() ssh.HostKeyCallback { + if len(h.hostKeys) == 0 { + return h.checkAndCollect + } + return h.checkOnly +} + +// checkAndCollect is a HostKeyCallback that records the host key provided by the SSH server. +func (h *_hostKeyCollector) checkAndCollect(_ string, _ net.Addr, key ssh.PublicKey) error { + marshaledKey := string(key.Marshal()) + h.mu.Lock() + defer h.mu.Unlock() + if _, ok := h.hostKeys[marshaledKey]; ok { + return nil + } + h.hostKeys[marshaledKey] = key + // If always returning nil here, GitHub Advanced Security may report "Use of insecure HostKeyCallback implementation". + // So, we return an error here to make the SSH client report the host key mismatch. + return errHostKeyMismatch +} + +// check is a HostKeyCallback that checks whether the host key has been collected. +func (h *_hostKeyCollector) checkOnly(_ string, _ net.Addr, key ssh.PublicKey) error { + h.mu.Lock() + defer h.mu.Unlock() + if _, ok := h.hostKeys[string(key.Marshal())]; ok { + return nil + } + // If always returning nil here, GitHub Advanced Security may report "Use of insecure HostKeyCallback implementation". + // So, we return an error here to make the SSH client report the host key mismatch. + return errHostKeyMismatch +}