Skip to content
This repository has been archived by the owner on Jan 17, 2021. It is now read-only.

Commit

Permalink
restructure SSH master code, apply requested fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
deansheather committed Jun 27, 2019
1 parent eee34f5 commit 3794755
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 82 deletions.
10 changes: 5 additions & 5 deletions main.go
Expand Up @@ -78,11 +78,11 @@ func (c *rootCmd) Run(fl *flag.FlagSet) {
}

err := sshCode(host, dir, options{
skipSync: c.skipSync,
sshFlags: c.sshFlags,
bindAddr: c.bindAddr,
syncBack: c.syncBack,
noReuseConnection: c.noReuseConnection,
skipSync: c.skipSync,
sshFlags: c.sshFlags,
bindAddr: c.bindAddr,
syncBack: c.syncBack,
reuseConnection: !c.noReuseConnection,
})

if err != nil {
Expand Down
172 changes: 95 additions & 77 deletions sshcode.go
Expand Up @@ -21,18 +21,21 @@ import (
)

const codeServerPath = "~/.cache/sshcode/sshcode-server"
const sshDirectory = "~/.ssh"
const sshDirectoryUnsafeModeMask = 0022
const sshControlPath = sshDirectory + "/control-%h-%p-%r"

const (
sshDirectory = "~/.ssh"
sshDirectoryUnsafeModeMask = 0022
sshControlPath = sshDirectory + "/control-%h-%p-%r"
)

type options struct {
skipSync bool
syncBack bool
noOpen bool
noReuseConnection bool
bindAddr string
remotePort string
sshFlags string
skipSync bool
syncBack bool
noOpen bool
reuseConnection bool
bindAddr string
remotePort string
sshFlags string
}

func sshCode(host, dir string, o options) error {
Expand All @@ -57,68 +60,19 @@ func sshCode(host, dir string, o options) error {
}

// Check the SSH directory's permissions and warn the user if it is not safe.
sshDirectoryMode, err := os.Lstat(expandPath(sshDirectory))
if err != nil {
if !o.noReuseConnection {
flog.Info("failed to stat %v directory, disabling connection reuse feature: %v", sshDirectory, err)
o.noReuseConnection = true
}
} else {
if !sshDirectoryMode.IsDir() {
if !o.noReuseConnection {
flog.Info("%v is not a directory, disabling connection reuse feature", sshDirectory)
o.noReuseConnection = true
} else {
flog.Info("warning: %v is not a directory", sshDirectory)
}
}
if sshDirectoryMode.Mode().Perm()&sshDirectoryUnsafeModeMask != 0 {
flog.Info("warning: the %v directory has unsafe permissions, they should only be writable by "+
"the owner (and files inside should be set to 0600)", sshDirectory)
}
}
o.reuseConnection = checkSSHDirectory(sshDirectory, o.reuseConnection)

// Start SSH master connection socket. This prevents multiple password prompts from appearing as authentication
// only happens on the initial connection.
if !o.noReuseConnection {
newSSHFlags := fmt.Sprintf(`%v -o "ControlPath=%v"`, o.sshFlags, sshControlPath)

// -MN means "start a master socket and don't open a session, just connect".
sshCmdStr := fmt.Sprintf(`exec ssh %v -MN %v`, newSSHFlags, host)
sshMasterCmd := exec.Command("sh", "-c", sshCmdStr)
sshMasterCmd.Stdin = os.Stdin
sshMasterCmd.Stdout = os.Stdout
sshMasterCmd.Stderr = os.Stderr
stopSSHMaster := func() {
if sshMasterCmd.Process != nil {
err := sshMasterCmd.Process.Signal(syscall.Signal(0))
if err != nil {
return
}
err = sshMasterCmd.Process.Signal(syscall.SIGTERM)
if err != nil {
flog.Error("failed to send SIGTERM to SSH master process: %v", err)
}
}
}
defer stopSSHMaster()

err = sshMasterCmd.Start()
go sshMasterCmd.Wait()
if o.reuseConnection {
flog.Info("starting SSH master connection...")
newSSHFlags, cancel, err := startSSHMaster(o.sshFlags, sshControlPath, host)
defer cancel()
if err != nil {
flog.Error("failed to start SSH master connection, disabling connection reuse feature: %v", err)
o.noReuseConnection = true
stopSSHMaster()
flog.Error("failed to start SSH master connection: %v", err)
o.reuseConnection = false
} else {
err = checkSSHMaster(sshMasterCmd, newSSHFlags, host)
if err != nil {
flog.Error("SSH master failed to be ready in time, disabling connection reuse feature: %v", err)
o.noReuseConnection = true
stopSSHMaster()
} else {
sshMasterCmd.Stdin = nil
o.sshFlags = newSSHFlags
}
o.sshFlags = newSSHFlags
}
}

Expand Down Expand Up @@ -226,12 +180,12 @@ func sshCode(host, dir string, o options) error {

err = syncExtensions(o.sshFlags, host, true)
if err != nil {
return xerrors.Errorf("failed to sync extensions back: %v", err)
return xerrors.Errorf("failed to sync extensions back: %w", err)
}

err = syncUserSettings(o.sshFlags, host, true)
if err != nil {
return xerrors.Errorf("failed to sync user settings settings back: %v", err)
return xerrors.Errorf("failed to sync user settings settings back: %w", err)
}

return nil
Expand Down Expand Up @@ -350,6 +304,74 @@ func randomPort() (string, error) {
return "", xerrors.Errorf("max number of tries exceeded: %d", maxTries)
}

// checkSSHDirectory performs sanity and safety checks on sshDirectory, and
// returns a new value for o.reuseConnection depending on the checks.
func checkSSHDirectory(sshDirectory string, reuseConnection bool) bool {
sshDirectoryMode, err := os.Lstat(expandPath(sshDirectory))
if err != nil {
if reuseConnection {
flog.Info("failed to stat %v directory, disabling connection reuse feature: %v", sshDirectory, err)
}
reuseConnection = false
} else {
if !sshDirectoryMode.IsDir() {
if reuseConnection {
flog.Info("%v is not a directory, disabling connection reuse feature", sshDirectory)
} else {
flog.Info("warning: %v is not a directory", sshDirectory)
}
reuseConnection = false
}
if sshDirectoryMode.Mode().Perm()&sshDirectoryUnsafeModeMask != 0 {
flog.Info("warning: the %v directory has unsafe permissions, they should only be writable by "+
"the owner (and files inside should be set to 0600)", sshDirectory)
}
}
return reuseConnection
}

// startSSHMaster starts an SSH master connection and waits for it to be ready.
// It returns a new set of SSH flags for child SSH processes to use.
func startSSHMaster(sshFlags string, sshControlPath string, host string) (string, func(), error) {
ctx, cancel := context.WithCancel(context.Background())

newSSHFlags := fmt.Sprintf(`%v -o "ControlPath=%v"`, sshFlags, sshControlPath)

// -MN means "start a master socket and don't open a session, just connect".
sshCmdStr := fmt.Sprintf(`exec ssh %v -MNq %v`, newSSHFlags, host)
sshMasterCmd := exec.CommandContext(ctx, "sh", "-c", sshCmdStr)
sshMasterCmd.Stdin = os.Stdin
sshMasterCmd.Stderr = os.Stderr

// Gracefully stop the SSH master.
stopSSHMaster := func() {
if sshMasterCmd.Process != nil {
if sshMasterCmd.ProcessState != nil && sshMasterCmd.ProcessState.Exited() {
return
}
err := sshMasterCmd.Process.Signal(syscall.SIGTERM)
if err != nil {
flog.Error("failed to send SIGTERM to SSH master process: %v", err)
}
}
cancel()
}

// Start ssh master and wait. Waiting prevents the process from becoming a zombie process if it dies before
// sshcode does, and allows sshMasterCmd.ProcessState to be populated.
err := sshMasterCmd.Start()
go sshMasterCmd.Wait()
if err != nil {
return "", stopSSHMaster, err
}
err = checkSSHMaster(sshMasterCmd, newSSHFlags, host)
if err != nil {
stopSSHMaster()
return "", stopSSHMaster, xerrors.Errorf("SSH master wasn't ready on time: %w", err)
}
return newSSHFlags, stopSSHMaster, nil
}

// checkSSHMaster polls every second for 30 seconds to check if the SSH master
// is ready.
func checkSSHMaster(sshMasterCmd *exec.Cmd, sshFlags string, host string) error {
Expand All @@ -359,16 +381,12 @@ func checkSSHMaster(sshMasterCmd *exec.Cmd, sshFlags string, host string) error
err error
)
for i := 0; i < maxTries; i++ {
// Check if the master is running
if sshMasterCmd.Process == nil {
return xerrors.Errorf("SSH master process not running")
}
err = sshMasterCmd.Process.Signal(syscall.Signal(0))
if err != nil {
return xerrors.Errorf("failed to check if SSH master process was alive: %v", err)
// Check if the master is running.
if sshMasterCmd.Process == nil || (sshMasterCmd.ProcessState != nil && sshMasterCmd.ProcessState.Exited()) {
return xerrors.Errorf("SSH master process is not running")
}

// Check if it's ready
// Check if it's ready.
sshCmdStr := fmt.Sprintf(`ssh %v -O check %v`, sshFlags, host)
sshCmd := exec.Command("sh", "-c", sshCmdStr)
err = sshCmd.Run()
Expand Down

0 comments on commit 3794755

Please sign in to comment.