-
-
Notifications
You must be signed in to change notification settings - Fork 1.8k
/
sftp.go
257 lines (214 loc) · 6.16 KB
/
sftp.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
package common
import (
"context"
"errors"
"fmt"
"io"
"net/url"
"os"
"os/signal"
"path/filepath"
"strconv"
"syscall"
"github.com/melbahja/goph"
"golang.org/x/crypto/ssh"
)
// SftpCopyInput is the input for CallSftpCopy
type SftpCopyInput struct {
	// AllowUknownHosts allows connecting to hosts with unknown host keys.
	// When true, host key verification is skipped entirely
	// (ssh.InsecureIgnoreHostKey). The misspelled field name is part of the
	// exported API and is intentionally left as-is.
	AllowUknownHosts bool
	// DestinationPath is the path on the remote host to copy the file to
	DestinationPath string
	// RemoteHost is the remote host to connect to, given as an ssh URL of the
	// form ssh://[user[:password]@]host[:port]; any other scheme is rejected
	RemoteHost string
	// SourcePath is the path of the local file to copy
	SourcePath string
}
// CallSftpCopy copies a file to a remote host via sftp.
//
// It is a convenience wrapper around CallSftpCopyWithContext that supplies
// context.Background(), so the copy has no caller-controlled deadline or
// cancellation; only the process signals intercepted by
// CallSftpCopyWithContext can interrupt it.
func CallSftpCopy(input SftpCopyInput) (SftpCopyResult, error) {
	background := context.Background()
	return CallSftpCopyWithContext(background, input)
}
// CallSftpCopyWithContext copies a file to a remote host via sftp with the given context.
//
// input.RemoteHost must be an ssh URL of the form
// ssh://[user[:password]@]host[:port]. When no user is given the USER
// environment variable is used, and when no port (or port 0) is given port 22
// is used. The private key is read from $DOKKU_ROOT/.ssh/id_ed25519, falling
// back to $DOKKU_ROOT/.ssh/id_rsa. Setting DOKKU_TRACE=1 logs the copy before
// it runs.
//
// While the copy is in progress, SIGHUP, SIGINT, SIGQUIT and SIGTERM are
// intercepted and cancel the copy instead of terminating the process. The
// signal registration and its goroutine are torn down before this function
// returns, so later signals reach the process normally again.
func CallSftpCopyWithContext(ctx context.Context, input SftpCopyInput) (SftpCopyResult, error) {
	u, err := url.Parse(input.RemoteHost)
	if err != nil {
		return SftpCopyResult{}, fmt.Errorf("failed to parse remote host: %w", err)
	}
	if u.Scheme == "" {
		// Redacted masks any password embedded in the URL so it never ends up in logs.
		return SftpCopyResult{}, fmt.Errorf("missing remote host ssh scheme in remote host: %s", u.Redacted())
	}
	if u.Scheme != "ssh" {
		return SftpCopyResult{}, fmt.Errorf("invalid remote host scheme: %s", u.Scheme)
	}

	username := ""
	password := ""
	if u.User != nil {
		username = u.User.Username()
		if pass, ok := u.User.Password(); ok {
			password = pass
		}
	}
	if username == "" {
		username = os.Getenv("USER")
	}

	port := 22
	if portStr := u.Port(); portStr != "" {
		portVal, err := strconv.Atoi(portStr)
		if err != nil {
			return SftpCopyResult{}, fmt.Errorf("failed to parse port: %w", err)
		}
		// url.Parse only checks that the port is numeric, not that it fits in a TCP port.
		if portVal > 65535 {
			return SftpCopyResult{}, fmt.Errorf("invalid remote host port: %d", portVal)
		}
		if portVal != 0 {
			port = portVal
		}
	}

	sshKeyPath := filepath.Join(os.Getenv("DOKKU_ROOT"), ".ssh/id_ed25519")
	if !FileExists(sshKeyPath) {
		sshKeyPath = filepath.Join(os.Getenv("DOKKU_ROOT"), ".ssh/id_rsa")
	}
	if !FileExists(sshKeyPath) {
		return SftpCopyResult{}, errors.New("ssh key not found at ~/.ssh/id_ed25519 or ~/.ssh/id_rsa")
	}

	cmd := SftpCopyTask{
		SourcePath:       input.SourcePath,
		DestinationPath:  input.DestinationPath,
		AllowUknownHosts: input.AllowUknownHosts,
		Hostname:         u.Hostname(),
		Port:             uint(port),
		Username:         username,
		Password:         password,
		SshKeyPath:       sshKeyPath,
	}
	if os.Getenv("DOKKU_TRACE") == "1" {
		cmd.PrintCommand = true
	}

	// Turn termination signals into context cancellation for the duration of
	// the copy only; signal.Stop restores default delivery on return.
	signals := make(chan os.Signal, 1)
	signal.Notify(signals, os.Interrupt, syscall.SIGHUP,
		syscall.SIGINT,
		syscall.SIGQUIT,
		syscall.SIGTERM)
	defer signal.Stop(signals)

	ctx, cancel := context.WithCancel(ctx)
	defer cancel()
	go func() {
		// Exit on ctx.Done too, otherwise this goroutine leaks when no signal arrives.
		select {
		case <-signals:
			cancel()
		case <-ctx.Done():
		}
	}()

	return cmd.Execute(ctx)
}
// SftpCopyTask describes a single local-to-remote file upload over sftp and is
// run with its Execute method.
type SftpCopyTask struct {
	// SourcePath is the path of the local file to copy
	SourcePath string
	// DestinationPath is the path on the remote host to copy the file to
	DestinationPath string
	// Shell run the command in a bash shell.
	// Note that the system must have `bash` installed in the PATH or in /bin/bash
	// NOTE(review): Execute does not read this field; it appears to be kept
	// for parity with other task types.
	Shell bool
	// Stdin connect a reader to stdin for the command
	// being executed.
	// NOTE(review): Execute does not read this field either.
	Stdin io.Reader
	// PrintCommand logs a debug line describing the copy before executing
	PrintCommand bool
	// AllowUknownHosts allows connecting to hosts with unknown host keys by
	// skipping host key verification (the misspelling is part of the public API)
	AllowUknownHosts bool
	// Hostname is the hostname to connect to
	Hostname string
	// Port is the port to connect to; Execute treats 0 as 22
	Port uint
	// Username is the username to connect with
	Username string
	// Password, when non-empty, is used for password authentication in place
	// of the key at SshKeyPath
	Password string
	// SshKeyPath is the path to the private ssh key to use; it is required by
	// Execute even when Password is set
	SshKeyPath string
}
// SftpCopyResult is the result of executing an SftpCopyTask
type SftpCopyResult struct {
	// ExitErr holds the error, if any, returned when closing the sftp session
	// after the destination file has been written. It is not duplicated in
	// the error returned alongside the result.
	ExitErr error
	// Cancelled reports whether the context passed to Execute had been
	// cancelled (context.Canceled, not a deadline) when Execute checked it.
	Cancelled bool
}
// Execute uploads the local file at task.SourcePath to task.DestinationPath on
// task.Hostname over sftp.
//
// SourcePath, DestinationPath, Hostname, SshKeyPath and Username are required,
// and Port defaults to 22. The private key at SshKeyPath is always loaded, but
// when Password is non-empty, password authentication is used instead. Host
// keys are verified against the default known_hosts file unless
// AllowUknownHosts is set, in which case known_hosts is not read at all.
//
// The source is streamed rather than read into memory. Cancelling ctx before
// the upload starts returns ctx.Err() without touching the destination.
// Cancelling it during the upload closes the ssh connection to abort the
// transfer and also returns ctx.Err(). SftpCopyResult.Cancelled reports
// whether that error is context.Canceled. An error from closing the sftp
// session after a successful upload is reported in SftpCopyResult.ExitErr
// rather than in the returned error.
func (task SftpCopyTask) Execute(ctx context.Context) (SftpCopyResult, error) {
	if task.SourcePath == "" {
		return SftpCopyResult{}, errors.New("source path is required")
	}
	if task.DestinationPath == "" {
		return SftpCopyResult{}, errors.New("destination path is required")
	}
	if task.Hostname == "" {
		return SftpCopyResult{}, errors.New("hostname is required")
	}
	if task.SshKeyPath == "" {
		return SftpCopyResult{}, errors.New("ssh key path is required")
	}
	if task.Username == "" {
		return SftpCopyResult{}, errors.New("username is required")
	}
	if task.Port == 0 {
		task.Port = 22
	}

	// Open the source before dialing so a missing, unreadable or directory
	// source fails fast without creating anything on the remote host.
	srcFile, err := os.Open(task.SourcePath)
	if err != nil {
		return SftpCopyResult{}, fmt.Errorf("failed to read source file: %w", err)
	}
	defer srcFile.Close()
	srcInfo, err := srcFile.Stat()
	if err != nil {
		return SftpCopyResult{}, fmt.Errorf("failed to read source file: %w", err)
	}
	if srcInfo.IsDir() {
		return SftpCopyResult{}, fmt.Errorf("failed to read source file: %s is a directory", task.SourcePath)
	}

	auth, err := goph.Key(task.SshKeyPath, "")
	if err != nil {
		return SftpCopyResult{}, fmt.Errorf("failed to load ssh key: %w", err)
	}

	// Only consult known_hosts when host keys are actually verified; a missing
	// file must not block a connection the caller explicitly made insecure.
	callback := ssh.InsecureIgnoreHostKey()
	if !task.AllowUknownHosts {
		callback, err = goph.DefaultKnownHosts()
		if err != nil {
			return SftpCopyResult{}, fmt.Errorf("failed to load known hosts: %w", err)
		}
	}

	connectionConf := goph.Config{
		User:     task.Username,
		Addr:     task.Hostname,
		Port:     task.Port,
		Timeout:  goph.DefaultTimeout,
		Callback: callback,
		Auth:     auth,
	}
	if task.Password != "" {
		connectionConf.Auth = goph.Password(task.Password)
	}

	client, err := goph.NewConn(&connectionConf)
	if err != nil {
		return SftpCopyResult{}, fmt.Errorf("failed to create ssh client: %w", err)
	}
	defer client.Close()

	// don't try to run if the context is already cancelled
	if err := ctx.Err(); err != nil {
		return SftpCopyResult{Cancelled: errors.Is(err, context.Canceled)}, err
	}

	// io.Copy does not observe ctx, so tear the connection down on
	// cancellation to make an in-flight transfer fail promptly. copyDone is
	// closed (by defer) before the deferred client.Close runs, which stops
	// this goroutine on every return path.
	copyDone := make(chan struct{})
	defer close(copyDone)
	go func() {
		select {
		case <-ctx.Done():
			client.Close()
		case <-copyDone:
		}
	}()

	// fail reports a failed step. When ctx is done, the failure was most
	// likely caused by the watcher above closing the connection, so the
	// context error is reported instead.
	fail := func(step string, err error) (SftpCopyResult, error) {
		if ctxErr := ctx.Err(); ctxErr != nil {
			return SftpCopyResult{Cancelled: errors.Is(ctxErr, context.Canceled)}, ctxErr
		}
		return SftpCopyResult{}, fmt.Errorf("%s: %w", step, err)
	}

	if task.PrintCommand {
		LogDebug(fmt.Sprintf("ssh %s@%s cp %s %v", task.Username, task.Hostname, task.SourcePath, task.DestinationPath))
	}

	sftpClient, err := client.NewSftp()
	if err != nil {
		return fail("failed to create sftp client", err)
	}
	// Close the sftp session on error paths; the success path closes it
	// explicitly so its error can be surfaced in ExitErr.
	sftpClosed := false
	defer func() {
		if !sftpClosed {
			_ = sftpClient.Close()
		}
	}()

	dstFile, err := sftpClient.Create(task.DestinationPath)
	if err != nil {
		return fail("failed to create destination file", err)
	}
	if _, err := io.Copy(dstFile, srcFile); err != nil {
		// best effort: the copy error is the one worth reporting
		_ = dstFile.Close()
		return fail("failed to write to destination file", err)
	}
	if err := dstFile.Close(); err != nil {
		return fail("failed to close destination file", err)
	}

	exitErr := sftpClient.Close()
	sftpClosed = true
	return SftpCopyResult{
		ExitErr:   exitErr,
		Cancelled: errors.Is(ctx.Err(), context.Canceled),
	}, ctx.Err()
}