diff --git a/pkg/hostagent/hostagent.go b/pkg/hostagent/hostagent.go index 53ba5fd1039..5172d0ccb18 100644 --- a/pkg/hostagent/hostagent.go +++ b/pkg/hostagent/hostagent.go @@ -244,7 +244,6 @@ func New(ctx context.Context, instName string, stdout io.Writer, signalCh chan o instName: instName, instSSHAddress: inst.SSHAddress, sshConfig: sshConfig, - portForwarder: newPortForwarder(sshConfig, sshLocalPort, rules, ignoreTCP, inst.VMType), grpcPortForwarder: portfwd.NewPortForwarder(rules, ignoreTCP, ignoreUDP), driver: limaDriver, signalCh: signalCh, @@ -254,6 +253,7 @@ func New(ctx context.Context, instName string, stdout io.Writer, signalCh chan o guestAgentAliveCh: make(chan struct{}), showProgress: o.showProgress, } + a.portForwarder = newPortForwarder(sshConfig, a.sshAddressPort, rules, ignoreTCP, inst.VMType) return a, nil } @@ -483,6 +483,12 @@ func (a *HostAgent) Info(_ context.Context) (*hostagentapi.Info, error) { return info, nil } +func (a *HostAgent) sshAddressPort() (sshAddress string, sshPort int) { + sshAddress = a.instSSHAddress + sshPort = a.sshLocalPort + return sshAddress, sshPort +} + func (a *HostAgent) startHostAgentRoutines(ctx context.Context) error { if *a.instConfig.Plain { msg := "Running in plain mode. Mounts, dynamic port forwarding, containerd, etc. will be ignored. Guest agent will not be running." @@ -589,7 +595,8 @@ sudo chown -R "${USER}" /run/host-services` } // Copy all config files _after_ the requirements are done for _, rule := range a.instConfig.CopyToHost { - if err := copyToHost(ctx, a.sshConfig, a.sshLocalPort, rule.HostFile, rule.GuestFile); err != nil { + sshAddress, sshPort := a.sshAddressPort() + if err := copyToHost(ctx, a.sshConfig, sshAddress, sshPort, rule.HostFile, rule.GuestFile); err != nil { errs = append(errs, err) } } @@ -636,10 +643,11 @@ func (a *HostAgent) watchGuestAgentEvents(ctx context.Context) { // Setup all socket forwards and defer their teardown if !(a.driver.Info().Features.SkipSocketForwarding) { logrus.Debugf("Forwarding unix sockets") + sshAddress, sshPort := a.sshAddressPort() for _, rule := range a.instConfig.PortForwards { if rule.GuestSocket != "" { local := hostAddress(rule, &guestagentapi.IPPort{}) - _ = forwardSSH(ctx, a.sshConfig, a.sshLocalPort, local, rule.GuestSocket, verbForward, rule.Reverse) + _ = forwardSSH(ctx, a.sshConfig, sshAddress, sshPort, local, rule.GuestSocket, verbForward, rule.Reverse) } } } @@ -650,17 +658,18 @@ func (a *HostAgent) watchGuestAgentEvents(ctx context.Context) { a.cleanUp(func() error { logrus.Debugf("Stop forwarding unix sockets") var errs []error + sshAddress, sshPort := a.sshAddressPort() for _, rule := range a.instConfig.PortForwards { if rule.GuestSocket != "" { local := hostAddress(rule, &guestagentapi.IPPort{}) // using ctx.Background() because ctx has already been cancelled - if err := forwardSSH(context.Background(), a.sshConfig, a.sshLocalPort, local, rule.GuestSocket, verbCancel, rule.Reverse); err != nil { + if err := forwardSSH(context.Background(), a.sshConfig, sshAddress, sshPort, local, rule.GuestSocket, verbCancel, rule.Reverse); err != nil { errs = append(errs, err) } } } if a.driver.ForwardGuestAgent() { - if err := forwardSSH(context.Background(), a.sshConfig, a.sshLocalPort, localUnix, remoteUnix, verbCancel, false); err != nil { + if err := forwardSSH(context.Background(), a.sshConfig, sshAddress, sshPort, localUnix, remoteUnix, verbCancel, false); err != nil { errs = append(errs, err) } } @@ -671,7 +680,8 @@ func (a *HostAgent) watchGuestAgentEvents(ctx context.Context) { if a.instConfig.MountInotify != nil && *a.instConfig.MountInotify { if a.client == nil || !isGuestAgentSocketAccessible(ctx, a.client) { if a.driver.ForwardGuestAgent() { - _ = forwardSSH(ctx, a.sshConfig, a.sshLocalPort, localUnix, remoteUnix, verbForward, false) + sshAddress, sshPort := a.sshAddressPort() + _ = forwardSSH(ctx, a.sshConfig, sshAddress, sshPort, localUnix, remoteUnix, verbForward, false) } } err := a.startInotify(ctx) @@ -687,7 +697,8 @@ func (a *HostAgent) watchGuestAgentEvents(ctx context.Context) { for { if a.client == nil || !isGuestAgentSocketAccessible(ctx, a.client) { if a.driver.ForwardGuestAgent() { - _ = forwardSSH(ctx, a.sshConfig, a.sshLocalPort, localUnix, remoteUnix, verbForward, false) + sshAddress, sshPort := a.sshAddressPort() + _ = forwardSSH(ctx, a.sshConfig, sshAddress, sshPort, localUnix, remoteUnix, verbForward, false) } } client, err := a.getOrCreateClient(ctx) @@ -711,6 +722,7 @@ func (a *HostAgent) watchGuestAgentEvents(ctx context.Context) { } func (a *HostAgent) addStaticPortForwardsFromList(ctx context.Context, staticPortForwards []limatype.PortForward) { + sshAddress, sshPort := a.sshAddressPort() for _, rule := range staticPortForwards { if rule.GuestSocket == "" { guest := &guestagentapi.IPPort{ @@ -721,7 +733,7 @@ func (a *HostAgent) addStaticPortForwardsFromList(ctx context.Context, staticPor local, remote := a.portForwarder.forwardingAddresses(guest) if local != "" { logrus.Infof("Setting up static TCP forwarding from %s to %s", remote, local) - if err := forwardTCP(ctx, a.sshConfig, a.sshLocalPort, local, remote, verbForward); err != nil { + if err := forwardTCP(ctx, a.sshConfig, sshAddress, sshPort, local, remote, verbForward); err != nil { logrus.WithError(err).Warnf("failed to set up static TCP forwarding %s -> %s", remote, local) } } @@ -832,11 +844,11 @@ const ( verbCancel = "cancel" ) -func executeSSH(ctx context.Context, sshConfig *ssh.SSHConfig, port int, command ...string) error { +func executeSSH(ctx context.Context, sshConfig *ssh.SSHConfig, sshAddress string, sshPort int, command ...string) error { args := sshConfig.Args() args = append(args, - "-p", strconv.Itoa(port), - "127.0.0.1", + "-p", strconv.Itoa(sshPort), + sshAddress, "--", ) args = append(args, command...) @@ -847,7 +859,7 @@ func executeSSH(ctx context.Context, sshConfig *ssh.SSHConfig, port int, command return nil } -func forwardSSH(ctx context.Context, sshConfig *ssh.SSHConfig, port int, local, remote, verb string, reverse bool) error { +func forwardSSH(ctx context.Context, sshConfig *ssh.SSHConfig, sshAddress string, sshPort int, local, remote, verb string, reverse bool) error { args := sshConfig.Args() args = append(args, "-T", @@ -865,8 +877,8 @@ func forwardSSH(ctx context.Context, sshConfig *ssh.SSHConfig, port int, local, args = append(args, "-N", "-f", - "-p", strconv.Itoa(port), - "127.0.0.1", + "-p", strconv.Itoa(sshPort), + sshAddress, "--", ) if strings.HasPrefix(local, "/") { @@ -874,7 +886,7 @@ func forwardSSH(ctx context.Context, sshConfig *ssh.SSHConfig, port int, local, case verbForward: if reverse { logrus.Infof("Forwarding %q (host) to %q (guest)", local, remote) - if err := executeSSH(ctx, sshConfig, port, "rm", "-f", remote); err != nil { + if err := executeSSH(ctx, sshConfig, sshAddress, sshPort, "rm", "-f", remote); err != nil { logrus.WithError(err).Warnf("Failed to clean up %q (guest) before setting up forwarding", remote) } } else { @@ -889,7 +901,7 @@ func forwardSSH(ctx context.Context, sshConfig *ssh.SSHConfig, port int, local, case verbCancel: if reverse { logrus.Infof("Stopping forwarding %q (host) to %q (guest)", local, remote) - if err := executeSSH(ctx, sshConfig, port, "rm", "-f", remote); err != nil { + if err := executeSSH(ctx, sshConfig, sshAddress, sshPort, "rm", "-f", remote); err != nil { logrus.WithError(err).Warnf("Failed to clean up %q (guest) after stopping forwarding", remote) } } else { @@ -910,7 +922,7 @@ func forwardSSH(ctx context.Context, sshConfig *ssh.SSHConfig, port int, local, if verb == verbForward && strings.HasPrefix(local, "/") { if reverse { logrus.WithError(err).Warnf("Failed to set up forward from %q (host) to %q (guest)", local, remote) - if err := executeSSH(ctx, sshConfig, port, "rm", "-f", remote); err != nil { + if err := executeSSH(ctx, sshConfig, sshAddress, sshPort, "rm", "-f", remote); err != nil { logrus.WithError(err).Warnf("Failed to clean up %q (guest) after forwarding failed", remote) } } else { @@ -944,10 +956,11 @@ func (a *HostAgent) watchCloudInitProgress(ctx context.Context) { Active: true, }) + sshAddress, sshPort := a.sshAddressPort() args := a.sshConfig.Args() args = append(args, - "-p", strconv.Itoa(a.sshLocalPort), - "127.0.0.1", + "-p", strconv.Itoa(sshPort), + sshAddress, "sh", "-c", `"if command -v systemctl >/dev/null 2>&1 && systemctl is-enabled -q cloud-init-main.service; then sudo journalctl -u cloud-init-main.service -b -S @0 -o cat -f @@ -1032,8 +1045,8 @@ func (a *HostAgent) watchCloudInitProgress(ctx context.Context) { finalArgs := a.sshConfig.Args() finalArgs = append(finalArgs, - "-p", strconv.Itoa(a.sshLocalPort), - "127.0.0.1", + "-p", strconv.Itoa(sshPort), + sshAddress, "sudo", "tail", "-n", "20", "/var/log/cloud-init-output.log", ) @@ -1073,11 +1086,11 @@ func isDeactivatedCloudInitMainService(line string) bool { return strings.HasPrefix(line, "cloud-init-main.service: consumed") } -func copyToHost(ctx context.Context, sshConfig *ssh.SSHConfig, port int, local, remote string) error { +func copyToHost(ctx context.Context, sshConfig *ssh.SSHConfig, sshAddress string, sshPort int, local, remote string) error { args := sshConfig.Args() args = append(args, - "-p", strconv.Itoa(port), - "127.0.0.1", + "-p", strconv.Itoa(sshPort), + sshAddress, "--", ) args = append(args, diff --git a/pkg/hostagent/mount.go b/pkg/hostagent/mount.go index 412afab1c44..f9fff74a0c8 100644 --- a/pkg/hostagent/mount.go +++ b/pkg/hostagent/mount.go @@ -61,12 +61,16 @@ func (a *HostAgent) setupMount(ctx context.Context, m limatype.Mount) (*mount, e } } + sshAddress, sshPort := a.sshAddressPort() + // Create a copy of sshConfig to avoid + // modifying HostAgent's sshConfig in case of Windows + sshConfig := *a.sshConfig rsf := &reversesshfs.ReverseSSHFS{ Driver: *m.SSHFS.SFTPDriver, - SSHConfig: a.sshConfig, + SSHConfig: &sshConfig, LocalPath: resolvedLocation, - Host: "127.0.0.1", - Port: a.sshLocalPort, + Host: sshAddress, + Port: sshPort, RemotePath: *m.MountPoint, Readonly: !(*m.Writable), SSHFSAdditionalArgs: []string{"-o", sshfsOptions}, diff --git a/pkg/hostagent/port.go b/pkg/hostagent/port.go index 1186a397e42..6ecc92e940d 100644 --- a/pkg/hostagent/port.go +++ b/pkg/hostagent/port.go @@ -16,24 +16,24 @@ import ( ) type portForwarder struct { - sshConfig *ssh.SSHConfig - sshHostPort int - rules []limatype.PortForward - ignore bool - vmType limatype.VMType + sshConfig *ssh.SSHConfig + sshAddressPort func() (string, int) + rules []limatype.PortForward + ignore bool + vmType limatype.VMType } const sshGuestPort = 22 var IPv4loopback1 = limayaml.IPv4loopback1 -func newPortForwarder(sshConfig *ssh.SSHConfig, sshHostPort int, rules []limatype.PortForward, ignore bool, vmType limatype.VMType) *portForwarder { +func newPortForwarder(sshConfig *ssh.SSHConfig, sshAddressPort func() (string, int), rules []limatype.PortForward, ignore bool, vmType limatype.VMType) *portForwarder { return &portForwarder{ - sshConfig: sshConfig, - sshHostPort: sshHostPort, - rules: rules, - ignore: ignore, - vmType: vmType, + sshConfig: sshConfig, + sshAddressPort: sshAddressPort, + rules: rules, + ignore: ignore, + vmType: vmType, } } @@ -87,6 +87,7 @@ func (pf *portForwarder) forwardingAddresses(guest *api.IPPort) (hostAddr, guest } func (pf *portForwarder) OnEvent(ctx context.Context, ev *api.Event) { + sshAddress, sshPort := pf.sshAddressPort() for _, f := range ev.RemovedLocalPorts { if f.Protocol != "tcp" { continue @@ -96,7 +97,7 @@ func (pf *portForwarder) OnEvent(ctx context.Context, ev *api.Event) { continue } logrus.Infof("Stopping forwarding TCP from %s to %s", remote, local) - if err := forwardTCP(ctx, pf.sshConfig, pf.sshHostPort, local, remote, verbCancel); err != nil { + if err := forwardTCP(ctx, pf.sshConfig, sshAddress, sshPort, local, remote, verbCancel); err != nil { logrus.WithError(err).Warnf("failed to stop forwarding tcp port %d", f.Port) } } @@ -112,7 +113,7 @@ func (pf *portForwarder) OnEvent(ctx context.Context, ev *api.Event) { continue } logrus.Infof("Forwarding TCP from %s to %s", remote, local) - if err := forwardTCP(ctx, pf.sshConfig, pf.sshHostPort, local, remote, verbForward); err != nil { + if err := forwardTCP(ctx, pf.sshConfig, sshAddress, sshPort, local, remote, verbForward); err != nil { logrus.WithError(err).Warnf("failed to set up forwarding tcp port %d (negligible if already forwarded)", f.Port) } } diff --git a/pkg/hostagent/port_darwin.go b/pkg/hostagent/port_darwin.go index 71737ee71f1..35c065c0998 100644 --- a/pkg/hostagent/port_darwin.go +++ b/pkg/hostagent/port_darwin.go @@ -20,9 +20,9 @@ import ( ) // forwardTCP is not thread-safe. -func forwardTCP(ctx context.Context, sshConfig *ssh.SSHConfig, port int, local, remote, verb string) error { +func forwardTCP(ctx context.Context, sshConfig *ssh.SSHConfig, sshAddress string, sshPort int, local, remote, verb string) error { if strings.HasPrefix(local, "/") { - return forwardSSH(ctx, sshConfig, port, local, remote, verb, false) + return forwardSSH(ctx, sshConfig, sshAddress, sshPort, local, remote, verb, false) } localIPStr, localPortStr, err := net.SplitHostPort(local) if err != nil { @@ -35,7 +35,7 @@ func forwardTCP(ctx context.Context, sshConfig *ssh.SSHConfig, port int, local, } if !localIP.Equal(IPv4loopback1) || localPort >= 1024 { - return forwardSSH(ctx, sshConfig, port, local, remote, verb, false) + return forwardSSH(ctx, sshConfig, sshAddress, sshPort, local, remote, verb, false) } // on macOS, listening on 127.0.0.1:80 requires root while 0.0.0.0:80 does not require root. @@ -50,7 +50,7 @@ func forwardTCP(ctx context.Context, sshConfig *ssh.SSHConfig, port int, local, localUnix := plf.unixAddr.Name _ = plf.Close() delete(pseudoLoopbackForwarders, local) - if err := forwardSSH(ctx, sshConfig, port, localUnix, remote, verb, false); err != nil { + if err := forwardSSH(ctx, sshConfig, sshAddress, sshPort, localUnix, remote, verb, false); err != nil { return err } } else { @@ -65,12 +65,12 @@ func forwardTCP(ctx context.Context, sshConfig *ssh.SSHConfig, port int, local, } localUnix := filepath.Join(localUnixDir, "sock") logrus.Debugf("forwarding %q to %q", localUnix, remote) - if err := forwardSSH(ctx, sshConfig, port, localUnix, remote, verb, false); err != nil { + if err := forwardSSH(ctx, sshConfig, sshAddress, sshPort, localUnix, remote, verb, false); err != nil { return err } plf, err := newPseudoLoopbackForwarder(localPort, localUnix) if err != nil { - if cancelErr := forwardSSH(ctx, sshConfig, port, localUnix, remote, verbCancel, false); cancelErr != nil { + if cancelErr := forwardSSH(ctx, sshConfig, sshAddress, sshPort, localUnix, remote, verbCancel, false); cancelErr != nil { logrus.WithError(cancelErr).Warnf("failed to cancel forwarding %q to %q", localUnix, remote) } return err diff --git a/pkg/hostagent/port_others.go b/pkg/hostagent/port_others.go index 8d218c25b35..ad5ced84d06 100644 --- a/pkg/hostagent/port_others.go +++ b/pkg/hostagent/port_others.go @@ -11,6 +11,6 @@ import ( "github.com/lima-vm/sshocker/pkg/ssh" ) -func forwardTCP(ctx context.Context, sshConfig *ssh.SSHConfig, port int, local, remote, verb string) error { - return forwardSSH(ctx, sshConfig, port, local, remote, verb, false) +func forwardTCP(ctx context.Context, sshConfig *ssh.SSHConfig, sshAddress string, sshPort int, local, remote, verb string) error { + return forwardSSH(ctx, sshConfig, sshAddress, sshPort, local, remote, verb, false) } diff --git a/pkg/hostagent/port_windows.go b/pkg/hostagent/port_windows.go index d8d19f0cbd1..7c0d235fd26 100644 --- a/pkg/hostagent/port_windows.go +++ b/pkg/hostagent/port_windows.go @@ -9,6 +9,6 @@ import ( "github.com/lima-vm/sshocker/pkg/ssh" ) -func forwardTCP(ctx context.Context, sshConfig *ssh.SSHConfig, port int, local, remote, verb string) error { - return forwardSSH(ctx, sshConfig, port, local, remote, verb, false) +func forwardTCP(ctx context.Context, sshConfig *ssh.SSHConfig, sshAddress string, sshPort int, local, remote, verb string) error { + return forwardSSH(ctx, sshConfig, sshAddress, sshPort, local, remote, verb, false) } diff --git a/pkg/sshutil/sshutil.go b/pkg/sshutil/sshutil.go index f2b12665fa8..e53bfa2f1da 100644 --- a/pkg/sshutil/sshutil.go +++ b/pkg/sshutil/sshutil.go @@ -373,7 +373,9 @@ func SSHArgsFromOpts(opts []string) []string { // SSHOptsRemovingControlPath removes ControlMaster, ControlPath, and ControlPersist options from SSH options. func SSHOptsRemovingControlPath(opts []string) []string { - return slices.DeleteFunc(opts, func(s string) bool { + // Create a copy of opts to avoid modifying the original slice, since slices.DeleteFunc modifies the slice in place. + copiedOpts := slices.Clone(opts) + return slices.DeleteFunc(copiedOpts, func(s string) bool { return strings.HasPrefix(s, "ControlMaster") || strings.HasPrefix(s, "ControlPath") || strings.HasPrefix(s, "ControlPersist") }) }