Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 37 additions & 24 deletions pkg/hostagent/hostagent.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,6 @@ func New(ctx context.Context, instName string, stdout io.Writer, signalCh chan o
instName: instName,
instSSHAddress: inst.SSHAddress,
sshConfig: sshConfig,
portForwarder: newPortForwarder(sshConfig, sshLocalPort, rules, ignoreTCP, inst.VMType),
grpcPortForwarder: portfwd.NewPortForwarder(rules, ignoreTCP, ignoreUDP),
driver: limaDriver,
signalCh: signalCh,
Expand All @@ -254,6 +253,7 @@ func New(ctx context.Context, instName string, stdout io.Writer, signalCh chan o
guestAgentAliveCh: make(chan struct{}),
showProgress: o.showProgress,
}
a.portForwarder = newPortForwarder(sshConfig, a.sshAddressPort, rules, ignoreTCP, inst.VMType)
return a, nil
}

Expand Down Expand Up @@ -483,6 +483,12 @@ func (a *HostAgent) Info(_ context.Context) (*hostagentapi.Info, error) {
return info, nil
}

func (a *HostAgent) sshAddressPort() (sshAddress string, sshPort int) {
sshAddress = a.instSSHAddress
sshPort = a.sshLocalPort
return sshAddress, sshPort
}

func (a *HostAgent) startHostAgentRoutines(ctx context.Context) error {
if *a.instConfig.Plain {
msg := "Running in plain mode. Mounts, dynamic port forwarding, containerd, etc. will be ignored. Guest agent will not be running."
Expand Down Expand Up @@ -589,7 +595,8 @@ sudo chown -R "${USER}" /run/host-services`
}
// Copy all config files _after_ the requirements are done
for _, rule := range a.instConfig.CopyToHost {
if err := copyToHost(ctx, a.sshConfig, a.sshLocalPort, rule.HostFile, rule.GuestFile); err != nil {
sshAddress, sshPort := a.sshAddressPort()
if err := copyToHost(ctx, a.sshConfig, sshAddress, sshPort, rule.HostFile, rule.GuestFile); err != nil {
errs = append(errs, err)
}
}
Expand Down Expand Up @@ -636,10 +643,11 @@ func (a *HostAgent) watchGuestAgentEvents(ctx context.Context) {
// Setup all socket forwards and defer their teardown
if !(a.driver.Info().Features.SkipSocketForwarding) {
logrus.Debugf("Forwarding unix sockets")
sshAddress, sshPort := a.sshAddressPort()
for _, rule := range a.instConfig.PortForwards {
if rule.GuestSocket != "" {
local := hostAddress(rule, &guestagentapi.IPPort{})
_ = forwardSSH(ctx, a.sshConfig, a.sshLocalPort, local, rule.GuestSocket, verbForward, rule.Reverse)
_ = forwardSSH(ctx, a.sshConfig, sshAddress, sshPort, local, rule.GuestSocket, verbForward, rule.Reverse)
}
}
}
Expand All @@ -650,17 +658,18 @@ func (a *HostAgent) watchGuestAgentEvents(ctx context.Context) {
a.cleanUp(func() error {
logrus.Debugf("Stop forwarding unix sockets")
var errs []error
sshAddress, sshPort := a.sshAddressPort()
for _, rule := range a.instConfig.PortForwards {
if rule.GuestSocket != "" {
local := hostAddress(rule, &guestagentapi.IPPort{})
// using ctx.Background() because ctx has already been cancelled
if err := forwardSSH(context.Background(), a.sshConfig, a.sshLocalPort, local, rule.GuestSocket, verbCancel, rule.Reverse); err != nil {
if err := forwardSSH(context.Background(), a.sshConfig, sshAddress, sshPort, local, rule.GuestSocket, verbCancel, rule.Reverse); err != nil {
errs = append(errs, err)
}
}
}
if a.driver.ForwardGuestAgent() {
if err := forwardSSH(context.Background(), a.sshConfig, a.sshLocalPort, localUnix, remoteUnix, verbCancel, false); err != nil {
if err := forwardSSH(context.Background(), a.sshConfig, sshAddress, sshPort, localUnix, remoteUnix, verbCancel, false); err != nil {
errs = append(errs, err)
}
}
Expand All @@ -671,7 +680,8 @@ func (a *HostAgent) watchGuestAgentEvents(ctx context.Context) {
if a.instConfig.MountInotify != nil && *a.instConfig.MountInotify {
if a.client == nil || !isGuestAgentSocketAccessible(ctx, a.client) {
if a.driver.ForwardGuestAgent() {
_ = forwardSSH(ctx, a.sshConfig, a.sshLocalPort, localUnix, remoteUnix, verbForward, false)
sshAddress, sshPort := a.sshAddressPort()
_ = forwardSSH(ctx, a.sshConfig, sshAddress, sshPort, localUnix, remoteUnix, verbForward, false)
}
}
err := a.startInotify(ctx)
Expand All @@ -687,7 +697,8 @@ func (a *HostAgent) watchGuestAgentEvents(ctx context.Context) {
for {
if a.client == nil || !isGuestAgentSocketAccessible(ctx, a.client) {
if a.driver.ForwardGuestAgent() {
_ = forwardSSH(ctx, a.sshConfig, a.sshLocalPort, localUnix, remoteUnix, verbForward, false)
sshAddress, sshPort := a.sshAddressPort()
_ = forwardSSH(ctx, a.sshConfig, sshAddress, sshPort, localUnix, remoteUnix, verbForward, false)
}
}
client, err := a.getOrCreateClient(ctx)
Expand All @@ -711,6 +722,7 @@ func (a *HostAgent) watchGuestAgentEvents(ctx context.Context) {
}

func (a *HostAgent) addStaticPortForwardsFromList(ctx context.Context, staticPortForwards []limatype.PortForward) {
sshAddress, sshPort := a.sshAddressPort()
for _, rule := range staticPortForwards {
if rule.GuestSocket == "" {
guest := &guestagentapi.IPPort{
Expand All @@ -721,7 +733,7 @@ func (a *HostAgent) addStaticPortForwardsFromList(ctx context.Context, staticPor
local, remote := a.portForwarder.forwardingAddresses(guest)
if local != "" {
logrus.Infof("Setting up static TCP forwarding from %s to %s", remote, local)
if err := forwardTCP(ctx, a.sshConfig, a.sshLocalPort, local, remote, verbForward); err != nil {
if err := forwardTCP(ctx, a.sshConfig, sshAddress, sshPort, local, remote, verbForward); err != nil {
logrus.WithError(err).Warnf("failed to set up static TCP forwarding %s -> %s", remote, local)
}
}
Expand Down Expand Up @@ -832,11 +844,11 @@ const (
verbCancel = "cancel"
)

func executeSSH(ctx context.Context, sshConfig *ssh.SSHConfig, port int, command ...string) error {
func executeSSH(ctx context.Context, sshConfig *ssh.SSHConfig, sshAddress string, sshPort int, command ...string) error {
args := sshConfig.Args()
args = append(args,
"-p", strconv.Itoa(port),
"127.0.0.1",
"-p", strconv.Itoa(sshPort),
sshAddress,
"--",
)
args = append(args, command...)
Expand All @@ -847,7 +859,7 @@ func executeSSH(ctx context.Context, sshConfig *ssh.SSHConfig, port int, command
return nil
}

func forwardSSH(ctx context.Context, sshConfig *ssh.SSHConfig, port int, local, remote, verb string, reverse bool) error {
func forwardSSH(ctx context.Context, sshConfig *ssh.SSHConfig, sshAddress string, sshPort int, local, remote, verb string, reverse bool) error {
args := sshConfig.Args()
args = append(args,
"-T",
Expand All @@ -865,16 +877,16 @@ func forwardSSH(ctx context.Context, sshConfig *ssh.SSHConfig, port int, local,
args = append(args,
"-N",
"-f",
"-p", strconv.Itoa(port),
"127.0.0.1",
"-p", strconv.Itoa(sshPort),
sshAddress,
"--",
)
if strings.HasPrefix(local, "/") {
switch verb {
case verbForward:
if reverse {
logrus.Infof("Forwarding %q (host) to %q (guest)", local, remote)
if err := executeSSH(ctx, sshConfig, port, "rm", "-f", remote); err != nil {
if err := executeSSH(ctx, sshConfig, sshAddress, sshPort, "rm", "-f", remote); err != nil {
logrus.WithError(err).Warnf("Failed to clean up %q (guest) before setting up forwarding", remote)
}
} else {
Expand All @@ -889,7 +901,7 @@ func forwardSSH(ctx context.Context, sshConfig *ssh.SSHConfig, port int, local,
case verbCancel:
if reverse {
logrus.Infof("Stopping forwarding %q (host) to %q (guest)", local, remote)
if err := executeSSH(ctx, sshConfig, port, "rm", "-f", remote); err != nil {
if err := executeSSH(ctx, sshConfig, sshAddress, sshPort, "rm", "-f", remote); err != nil {
logrus.WithError(err).Warnf("Failed to clean up %q (guest) after stopping forwarding", remote)
}
} else {
Expand All @@ -910,7 +922,7 @@ func forwardSSH(ctx context.Context, sshConfig *ssh.SSHConfig, port int, local,
if verb == verbForward && strings.HasPrefix(local, "/") {
if reverse {
logrus.WithError(err).Warnf("Failed to set up forward from %q (host) to %q (guest)", local, remote)
if err := executeSSH(ctx, sshConfig, port, "rm", "-f", remote); err != nil {
if err := executeSSH(ctx, sshConfig, sshAddress, sshPort, "rm", "-f", remote); err != nil {
logrus.WithError(err).Warnf("Failed to clean up %q (guest) after forwarding failed", remote)
}
} else {
Expand Down Expand Up @@ -944,10 +956,11 @@ func (a *HostAgent) watchCloudInitProgress(ctx context.Context) {
Active: true,
})

sshAddress, sshPort := a.sshAddressPort()
args := a.sshConfig.Args()
args = append(args,
"-p", strconv.Itoa(a.sshLocalPort),
"127.0.0.1",
"-p", strconv.Itoa(sshPort),
sshAddress,
"sh", "-c",
`"if command -v systemctl >/dev/null 2>&1 && systemctl is-enabled -q cloud-init-main.service; then
sudo journalctl -u cloud-init-main.service -b -S @0 -o cat -f
Expand Down Expand Up @@ -1032,8 +1045,8 @@ func (a *HostAgent) watchCloudInitProgress(ctx context.Context) {

finalArgs := a.sshConfig.Args()
finalArgs = append(finalArgs,
"-p", strconv.Itoa(a.sshLocalPort),
"127.0.0.1",
"-p", strconv.Itoa(sshPort),
sshAddress,
"sudo", "tail", "-n", "20", "/var/log/cloud-init-output.log",
)

Expand Down Expand Up @@ -1073,11 +1086,11 @@ func isDeactivatedCloudInitMainService(line string) bool {
return strings.HasPrefix(line, "cloud-init-main.service: consumed")
}

func copyToHost(ctx context.Context, sshConfig *ssh.SSHConfig, port int, local, remote string) error {
func copyToHost(ctx context.Context, sshConfig *ssh.SSHConfig, sshAddress string, sshPort int, local, remote string) error {
args := sshConfig.Args()
args = append(args,
"-p", strconv.Itoa(port),
"127.0.0.1",
"-p", strconv.Itoa(sshPort),
sshAddress,
"--",
)
args = append(args,
Expand Down
10 changes: 7 additions & 3 deletions pkg/hostagent/mount.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,16 @@ func (a *HostAgent) setupMount(ctx context.Context, m limatype.Mount) (*mount, e
}
}

sshAddress, sshPort := a.sshAddressPort()
// Create a copy of sshConfig to avoid
// modifying HostAgent's sshConfig in case of Windows
sshConfig := *a.sshConfig
rsf := &reversesshfs.ReverseSSHFS{
Driver: *m.SSHFS.SFTPDriver,
SSHConfig: a.sshConfig,
SSHConfig: &sshConfig,
LocalPath: resolvedLocation,
Host: "127.0.0.1",
Port: a.sshLocalPort,
Host: sshAddress,
Port: sshPort,
RemotePath: *m.MountPoint,
Readonly: !(*m.Writable),
SSHFSAdditionalArgs: []string{"-o", sshfsOptions},
Expand Down
27 changes: 14 additions & 13 deletions pkg/hostagent/port.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,24 @@ import (
)

type portForwarder struct {
sshConfig *ssh.SSHConfig
sshHostPort int
rules []limatype.PortForward
ignore bool
vmType limatype.VMType
sshConfig *ssh.SSHConfig
sshAddressPort func() (string, int)
rules []limatype.PortForward
ignore bool
vmType limatype.VMType
}

const sshGuestPort = 22

var IPv4loopback1 = limayaml.IPv4loopback1

func newPortForwarder(sshConfig *ssh.SSHConfig, sshHostPort int, rules []limatype.PortForward, ignore bool, vmType limatype.VMType) *portForwarder {
func newPortForwarder(sshConfig *ssh.SSHConfig, sshAddressPort func() (string, int), rules []limatype.PortForward, ignore bool, vmType limatype.VMType) *portForwarder {
return &portForwarder{
sshConfig: sshConfig,
sshHostPort: sshHostPort,
rules: rules,
ignore: ignore,
vmType: vmType,
sshConfig: sshConfig,
sshAddressPort: sshAddressPort,
rules: rules,
ignore: ignore,
vmType: vmType,
}
}

Expand Down Expand Up @@ -87,6 +87,7 @@ func (pf *portForwarder) forwardingAddresses(guest *api.IPPort) (hostAddr, guest
}

func (pf *portForwarder) OnEvent(ctx context.Context, ev *api.Event) {
sshAddress, sshPort := pf.sshAddressPort()
for _, f := range ev.RemovedLocalPorts {
if f.Protocol != "tcp" {
continue
Expand All @@ -96,7 +97,7 @@ func (pf *portForwarder) OnEvent(ctx context.Context, ev *api.Event) {
continue
}
logrus.Infof("Stopping forwarding TCP from %s to %s", remote, local)
if err := forwardTCP(ctx, pf.sshConfig, pf.sshHostPort, local, remote, verbCancel); err != nil {
if err := forwardTCP(ctx, pf.sshConfig, sshAddress, sshPort, local, remote, verbCancel); err != nil {
logrus.WithError(err).Warnf("failed to stop forwarding tcp port %d", f.Port)
}
}
Expand All @@ -112,7 +113,7 @@ func (pf *portForwarder) OnEvent(ctx context.Context, ev *api.Event) {
continue
}
logrus.Infof("Forwarding TCP from %s to %s", remote, local)
if err := forwardTCP(ctx, pf.sshConfig, pf.sshHostPort, local, remote, verbForward); err != nil {
if err := forwardTCP(ctx, pf.sshConfig, sshAddress, sshPort, local, remote, verbForward); err != nil {
logrus.WithError(err).Warnf("failed to set up forwarding tcp port %d (negligible if already forwarded)", f.Port)
}
}
Expand Down
12 changes: 6 additions & 6 deletions pkg/hostagent/port_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ import (
)

// forwardTCP is not thread-safe.
func forwardTCP(ctx context.Context, sshConfig *ssh.SSHConfig, port int, local, remote, verb string) error {
func forwardTCP(ctx context.Context, sshConfig *ssh.SSHConfig, sshAddress string, sshPort int, local, remote, verb string) error {
if strings.HasPrefix(local, "/") {
return forwardSSH(ctx, sshConfig, port, local, remote, verb, false)
return forwardSSH(ctx, sshConfig, sshAddress, sshPort, local, remote, verb, false)
}
localIPStr, localPortStr, err := net.SplitHostPort(local)
if err != nil {
Expand All @@ -35,7 +35,7 @@ func forwardTCP(ctx context.Context, sshConfig *ssh.SSHConfig, port int, local,
}

if !localIP.Equal(IPv4loopback1) || localPort >= 1024 {
return forwardSSH(ctx, sshConfig, port, local, remote, verb, false)
return forwardSSH(ctx, sshConfig, sshAddress, sshPort, local, remote, verb, false)
}

// on macOS, listening on 127.0.0.1:80 requires root while 0.0.0.0:80 does not require root.
Expand All @@ -50,7 +50,7 @@ func forwardTCP(ctx context.Context, sshConfig *ssh.SSHConfig, port int, local,
localUnix := plf.unixAddr.Name
_ = plf.Close()
delete(pseudoLoopbackForwarders, local)
if err := forwardSSH(ctx, sshConfig, port, localUnix, remote, verb, false); err != nil {
if err := forwardSSH(ctx, sshConfig, sshAddress, sshPort, localUnix, remote, verb, false); err != nil {
return err
}
} else {
Expand All @@ -65,12 +65,12 @@ func forwardTCP(ctx context.Context, sshConfig *ssh.SSHConfig, port int, local,
}
localUnix := filepath.Join(localUnixDir, "sock")
logrus.Debugf("forwarding %q to %q", localUnix, remote)
if err := forwardSSH(ctx, sshConfig, port, localUnix, remote, verb, false); err != nil {
if err := forwardSSH(ctx, sshConfig, sshAddress, sshPort, localUnix, remote, verb, false); err != nil {
return err
}
plf, err := newPseudoLoopbackForwarder(localPort, localUnix)
if err != nil {
if cancelErr := forwardSSH(ctx, sshConfig, port, localUnix, remote, verbCancel, false); cancelErr != nil {
if cancelErr := forwardSSH(ctx, sshConfig, sshAddress, sshPort, localUnix, remote, verbCancel, false); cancelErr != nil {
logrus.WithError(cancelErr).Warnf("failed to cancel forwarding %q to %q", localUnix, remote)
}
return err
Expand Down
4 changes: 2 additions & 2 deletions pkg/hostagent/port_others.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ import (
"github.com/lima-vm/sshocker/pkg/ssh"
)

func forwardTCP(ctx context.Context, sshConfig *ssh.SSHConfig, port int, local, remote, verb string) error {
return forwardSSH(ctx, sshConfig, port, local, remote, verb, false)
func forwardTCP(ctx context.Context, sshConfig *ssh.SSHConfig, sshAddress string, sshPort int, local, remote, verb string) error {
return forwardSSH(ctx, sshConfig, sshAddress, sshPort, local, remote, verb, false)
}
4 changes: 2 additions & 2 deletions pkg/hostagent/port_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ import (
"github.com/lima-vm/sshocker/pkg/ssh"
)

func forwardTCP(ctx context.Context, sshConfig *ssh.SSHConfig, port int, local, remote, verb string) error {
return forwardSSH(ctx, sshConfig, port, local, remote, verb, false)
func forwardTCP(ctx context.Context, sshConfig *ssh.SSHConfig, sshAddress string, sshPort int, local, remote, verb string) error {
return forwardSSH(ctx, sshConfig, sshAddress, sshPort, local, remote, verb, false)
}
4 changes: 3 additions & 1 deletion pkg/sshutil/sshutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,9 @@ func SSHArgsFromOpts(opts []string) []string {

// SSHOptsRemovingControlPath removes ControlMaster, ControlPath, and ControlPersist options from SSH options.
func SSHOptsRemovingControlPath(opts []string) []string {
return slices.DeleteFunc(opts, func(s string) bool {
// Create a copy of opts to avoid modifying the original slice, since slices.DeleteFunc modifies the slice in place.
copiedOpts := slices.Clone(opts)
return slices.DeleteFunc(copiedOpts, func(s string) bool {
return strings.HasPrefix(s, "ControlMaster") || strings.HasPrefix(s, "ControlPath") || strings.HasPrefix(s, "ControlPersist")
})
}
Expand Down