/
scp.go
129 lines (113 loc) · 2.8 KB
/
scp.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
// Copyright 2021 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package executor
import (
"bufio"
"fmt"
"io"
"io/fs"
"os"
"path/filepath"
"strconv"
"strings"
"github.com/luyomo/OhMyTiUP/pkg/utils"
"golang.org/x/crypto/ssh"
)
// ScpDownload downloads a file from remote with SCP.
// The implementation is partially inspired by github.com/dtylman/scp
//
// It starts `scp -f src` on the remote side of session (adding `-l limit`
// when limit > 0), reads the single-file copy header sent by the remote,
// and writes the announced number of bytes to dst. The parent directory of
// dst is created with utils.CreateDir, and dst is created or truncated with
// the file mode announced by the remote.
//
// client is accepted for interface compatibility; this function does not
// use it.
//
// An error is returned when the pipes cannot be obtained, the remote command
// cannot be started, the remote reports an error or sends a malformed header,
// the local file cannot be written, or the remote command exits unsuccessfully.
func ScpDownload(session *ssh.Session, client *ssh.Client, src, dst string, limit int) error {
	r, err := session.StdoutPipe()
	if err != nil {
		return err
	}
	bufr := bufio.NewReader(r)

	w, err := session.StdinPipe()
	if err != nil {
		return err
	}

	copyF := func() error {
		// Parse the SCP control line, expected as "C<mode> <size> <name>".
		line, _, err := bufr.ReadLine()
		if err != nil {
			return err
		}
		if len(line) == 0 {
			return fmt.Errorf("empty scp command, should be 'C'")
		}
		// A leading 0x01 (warning) or 0x02 (fatal) byte introduces an error
		// message from the remote scp; surface its text instead of treating
		// it as an unknown command.
		if line[0] == 1 || line[0] == 2 {
			return fmt.Errorf("remote scp error: %s", strings.TrimSpace(string(line[1:])))
		}
		if line[0] != 'C' {
			return fmt.Errorf("incorrect scp command %q, should be 'C'", line[0])
		}

		// Validate the shape of the header before slicing/indexing into it so
		// that a truncated line yields an error rather than a panic.
		fields := strings.Fields(string(line))
		if len(line) < 5 || len(fields) < 2 {
			return fmt.Errorf("malformed scp command %q", line)
		}

		// Base 0 lets the leading "0" of e.g. "0644" select octal.
		mode, err := strconv.ParseUint(string(line[1:5]), 0, 32)
		if err != nil {
			return fmt.Errorf("error parsing file mode; %w", err)
		}

		// Parse the size as 64-bit so large files work on 32-bit platforms,
		// and do it before touching dst so a bad header never truncates an
		// existing file.
		size, err := strconv.ParseInt(fields[1], 10, 64)
		if err != nil {
			return fmt.Errorf("error parsing file size; %w", err)
		}
		if size < 0 {
			return fmt.Errorf("error parsing file size; negative size %d", size)
		}

		// Prepare the destination file.
		targetPath := filepath.Dir(dst)
		if err := utils.CreateDir(targetPath); err != nil {
			return err
		}
		targetFile, err := os.OpenFile(dst, os.O_RDWR|os.O_CREATE|os.O_TRUNC, fs.FileMode(mode))
		if err != nil {
			return err
		}
		defer targetFile.Close()

		// Acknowledge the header so the remote starts sending file content.
		if err := ack(w); err != nil {
			return err
		}

		// Transfer the data. io.CopyN returns a non-nil error whenever fewer
		// than size bytes were copied, so no separate length check is needed.
		if _, err := io.CopyN(targetFile, bufr, size); err != nil {
			return err
		}
		if err := targetFile.Sync(); err != nil {
			return err
		}

		// Acknowledge receipt of the file content.
		return ack(w)
	}

	// Buffered so the goroutine can always deliver its result and exit, even
	// if this function has already returned on an earlier error.
	copyErrC := make(chan error, 1)
	go func() {
		defer w.Close()
		copyErrC <- copyF()
	}()

	// NOTE(review): src is interpolated into the remote shell command without
	// quoting; callers must not pass untrusted paths here.
	remoteCmd := fmt.Sprintf("scp -f %s", src)
	if limit > 0 {
		remoteCmd = fmt.Sprintf("scp -l %d -f %s", limit, src)
	}

	if err := session.Start(remoteCmd); err != nil {
		return err
	}
	if err := ack(w); err != nil { // send an empty byte to start transfer
		return err
	}

	if err := <-copyErrC; err != nil {
		return err
	}
	return session.Wait()
}
// ack writes a single zero byte to w, which is the acknowledgement the remote
// scp process waits for before it proceeds.
//
// It returns an error wrapping the underlying write error when the write
// fails, and a size-mismatch error when the writer reports fewer bytes
// written than requested without returning an error.
func ack(w io.Writer) error {
	msg := []byte{0}
	n, err := w.Write(msg)
	if err != nil {
		// %w keeps the original error reachable through errors.Is/errors.As.
		return fmt.Errorf("fail to send response to remote: %w", err)
	}
	if n < len(msg) {
		return fmt.Errorf("fail to send response to remote, size mismatch")
	}
	return nil
}