Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/fosrl/cli/cmd/down"
"github.com/fosrl/cli/cmd/list"
"github.com/fosrl/cli/cmd/logs"
"github.com/fosrl/cli/cmd/scp"
selectcmd "github.com/fosrl/cli/cmd/select"
"github.com/fosrl/cli/cmd/ssh"
"github.com/fosrl/cli/cmd/status"
Expand Down Expand Up @@ -67,6 +68,7 @@ func RootCommand(initResources bool) (*cobra.Command, error) {
}

cmd.AddCommand(ssh.SSHCmd())
cmd.AddCommand(scp.SCPCmd())
cmd.AddCommand(update.UpdateCmd())
cmd.AddCommand(version.VersionCmd())
cmd.AddCommand(login.LoginCmd())
Expand Down
123 changes: 123 additions & 0 deletions cmd/scp/connect.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package scp

import (
"fmt"
"time"

"github.com/charmbracelet/bubbles/spinner"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/fosrl/cli/internal/olm"
)

const (
siteAppearTimeout = 15 * time.Second
siteConnectTimeout = 30 * time.Second
pollInterval = 500 * time.Millisecond
)

// siteConnectedMsg is sent to the bubbletea program when any site connects.
type siteConnectedMsg struct{}

// siteConnectTimedOutMsg is sent when the connection poll deadline is exceeded.
type siteConnectTimedOutMsg struct{}

// connectSpinnerModel is a minimal bubbletea model that displays a spinner
// while a background goroutine polls for the site connection.
type connectSpinnerModel struct {
spinner spinner.Model
timedOut bool
}

func newConnectSpinnerModel() connectSpinnerModel {
s := spinner.New()
s.Spinner = spinner.Dot
s.Style = lipgloss.NewStyle().Foreground(lipgloss.Color("6")) // cyan
return connectSpinnerModel{spinner: s}
}

func (m connectSpinnerModel) Init() tea.Cmd {
return m.spinner.Tick
}

func (m connectSpinnerModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg.(type) {
case siteConnectedMsg:
return m, tea.Quit
case siteConnectTimedOutMsg:
m.timedOut = true
return m, tea.Quit
}
var cmd tea.Cmd
m.spinner, cmd = m.spinner.Update(msg)
return m, cmd
}

func (m connectSpinnerModel) View() string {
return fmt.Sprintf("%s Connecting...\n", m.spinner.View())
}

// waitForAnySiteConnection waits for at least one site from siteIDs to appear
// in the olm status output and become connected.
func waitForAnySiteConnection(client *olm.Client, siteIDs []int) error {
deadline := time.Now().Add(siteAppearTimeout)
appearedIDs := map[int]bool{}
anyConnected := false

for time.Now().Before(deadline) {
status, err := client.GetStatus()
if err == nil {
for _, siteID := range siteIDs {
if peer, ok := status.PeerStatuses[siteID]; ok {
appearedIDs[siteID] = true
if peer.Connected {
anyConnected = true
}
}
}
}
if len(appearedIDs) > 0 {
break
}
time.Sleep(pollInterval)
}

if len(appearedIDs) == 0 {
return fmt.Errorf("no sites were added to the connection; the JIT connect request may have failed")
}

if anyConnected {
return nil
}

model := newConnectSpinnerModel()
program := tea.NewProgram(model)

go func() {
deadline := time.Now().Add(siteConnectTimeout)
for time.Now().Before(deadline) {
status, err := client.GetStatus()
if err == nil {
for siteID := range appearedIDs {
if peer, ok := status.PeerStatuses[siteID]; ok && peer.Connected {
program.Send(siteConnectedMsg{})
return
}
}
}
time.Sleep(pollInterval)
}
program.Send(siteConnectTimedOutMsg{})
}()

finalModel, err := program.Run()
if err != nil {
return fmt.Errorf("spinner error: %w", err)
}

if finalModel.(connectSpinnerModel).timedOut {
return fmt.Errorf("Timed out waiting for site to connect. Please disconnect (down) then reconnect (up) the client and try again.")
}

return nil
}
94 changes: 94 additions & 0 deletions cmd/scp/exec_args.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package scp

import (
"fmt"
"runtime"
"strconv"
"strings"
)

func buildExecSCPArgs(scpPath string, opts RunOpts, keyPath, certPath string) []string {
args := []string{scpPath}
if keyPath != "" {
args = append(args, "-i", keyPath)
}
if certPath != "" {
args = append(args, "-o", "CertificateFile="+certPath)
}
args = append(args,
"-o", "PubkeyAuthentication=yes",
"-o", "PreferredAuthentications=publickey",
"-o", "IdentitiesOnly=yes",
"-o", "PasswordAuthentication=no",
"-o", "KbdInteractiveAuthentication=no",
)
args = append(args, "-o", "StrictHostKeyChecking=no", "-o", "UserKnownHostsFile=/dev/null", "-o", "LogLevel=ERROR")
if opts.Port > 0 {
args = append(args, "-P", strconv.Itoa(opts.Port))
}
args = append(args, opts.Passthrough.Options...)
args = append(args, rewriteSCPOperands(opts)...)
return args
}

func rewriteSCPOperands(opts RunOpts) []string {
if len(opts.Passthrough.RemoteCommand) == 0 {
return nil
}
rewritten := make([]string, 0, len(opts.Passthrough.RemoteCommand))
for _, operand := range opts.Passthrough.RemoteCommand {
rewritten = append(rewritten, rewriteSCPOperand(operand, opts.ResourceID, opts.User, opts.Hostname))
}
return rewritten
}

func rewriteSCPOperand(operand, resourceID, user, hostname string) string {
hostSpec, pathPart, ok := splitSCPOperand(operand)
if !ok {
return operand
}
if !matchesTargetHost(hostSpec, resourceID) {
return operand
}
return fmt.Sprintf("%s:%s", hostWithUser(hostname, user), pathPart)
}

func splitSCPOperand(s string) (hostSpec string, pathPart string, ok bool) {
if s == "" {
return "", "", false
}
if runtime.GOOS == "windows" {
if len(s) >= 2 && s[1] == ':' {
if len(s) == 2 {
return "", "", false
}
next := s[2]
if next == '\\' || next == '/' {
return "", "", false
}
}
}
idx := strings.IndexByte(s, ':')
if idx <= 0 || idx == len(s)-1 {
return "", "", false
}
return s[:idx], s[idx+1:], true
}

func matchesTargetHost(hostSpec, resourceID string) bool {
hostOnly := hostSpec
if u, h, hasAt := strings.Cut(hostSpec, "@"); hasAt {
if u == "" || h == "" {
return false
}
hostOnly = h
}
return hostOnly == resourceID
}

func hostWithUser(hostname, user string) string {
if user == "" {
return hostname
}
return user + "@" + hostname
}
17 changes: 17 additions & 0 deletions cmd/scp/exec_scp_env.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package scp

import (
"os"
"strings"
)

// envSCPBinary overrides the scp(1) executable used by RunExec on all platforms when non-empty.
const envSCPBinary = "PANGOLIN_SCP_BINARY"

func scpBinaryFromEnv() (path string, ok bool) {
p := strings.TrimSpace(os.Getenv(envSCPBinary))
if p == "" {
return "", false
}
return p, true
}
136 changes: 136 additions & 0 deletions cmd/scp/runner_exec_unix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
//go:build !windows
// +build !windows

package scp

import (
"errors"
"fmt"
"os"
"os/exec"
)

// execSCPSearchPaths are fallback locations for the scp executable when not in PATH.
var execSCPSearchPaths = []string{
"/usr/bin/scp",
"/usr/local/bin/scp",
`C:\\Windows\\System32\\OpenSSH\\scp.exe`,
}

func findExecSCPPath() (string, error) {
if p, ok := scpBinaryFromEnv(); ok {
if isExecutable(p) {
return p, nil
}
return "", fmt.Errorf("%s=%q: not an executable file", envSCPBinary, p)
}
if path, err := exec.LookPath("scp"); err == nil {
return path, nil
}
for _, p := range execSCPSearchPaths {
if isExecutable(p) {
return p, nil
}
}
return "", errors.New("scp executable not found in PATH or in common locations")
}

func isExecutable(path string) bool {
info, err := os.Stat(path)
if err != nil || info.IsDir() {
return false
}
return info.Mode()&0o111 != 0
}

func execExitCode(err error) int {
if err == nil {
return 0
}
if exitErr, ok := err.(*exec.ExitError); ok {
return exitErr.ExitCode()
}
return 1
}

// RunExec runs scp via the system scp binary. opts.PrivateKeyPEM and opts.Certificate
// must be set (JIT key + signed cert).
func RunExec(opts RunOpts) (int, error) {
scpPath, err := findExecSCPPath()
if err != nil {
return 1, err
}

keyPath, certPath, cleanup, err := writeExecKeyFiles(opts)
if err != nil {
return 1, err
}
if cleanup != nil {
defer cleanup()
}

argv := buildExecSCPArgs(scpPath, opts, keyPath, certPath)
cmd := exec.Command(argv[0], argv[1:]...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return execExitCode(err), nil
}
return 0, nil
}

func writeExecKeyFiles(opts RunOpts) (keyPath, certPath string, cleanup func(), err error) {
if opts.PrivateKeyPEM == "" {
return "", "", nil, errors.New("private key required (JIT flow)")
}

keyFile, err := os.CreateTemp("", "pangolin-ssh-key-*")
if err != nil {
return "", "", nil, err
}
if _, err := keyFile.WriteString(opts.PrivateKeyPEM); err != nil {
keyFile.Close()
os.Remove(keyFile.Name())
return "", "", nil, err
}
if err := keyFile.Chmod(0o600); err != nil {
keyFile.Close()
os.Remove(keyFile.Name())
return "", "", nil, err
}
if err := keyFile.Close(); err != nil {
os.Remove(keyFile.Name())
return "", "", nil, err
}
keyPath = keyFile.Name()

if opts.Certificate != "" {
certFile, err := os.CreateTemp("", "pangolin-ssh-cert-*")
if err != nil {
os.Remove(keyPath)
return "", "", nil, err
}
if _, err := certFile.WriteString(opts.Certificate); err != nil {
certFile.Close()
os.Remove(certFile.Name())
os.Remove(keyPath)
return "", "", nil, err
}
if err := certFile.Close(); err != nil {
os.Remove(certFile.Name())
os.Remove(keyPath)
return "", "", nil, err
}
certPath = certFile.Name()
}

cleanup = func() {
os.Remove(keyPath)
if certPath != "" {
os.Remove(certPath)
}
}

return keyPath, certPath, cleanup, nil
}
Loading
Loading