forked from dan-v/awslambdaproxy
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ssh.go
112 lines (95 loc) · 2.74 KB
/
ssh.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
package awslambdaproxy
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"github.com/pkg/errors"
"golang.org/x/crypto/ssh"
"io/ioutil"
"log"
"os"
"os/signal"
"os/user"
"strings"
)
// sshManager holds the RSA key pair whose public half is added to, and later
// removed from, the current user's authorized_keys file.
type sshManager struct {
	// privateKey is the RSA key the public authorized_keys entry is derived from.
	privateKey *rsa.PrivateKey
}
// getPrivateKeyBytes returns the private key serialized as PKCS#1 DER and
// PEM-encoded in a block of type "RSA PRIVATE KEY".
func (s *sshManager) getPrivateKeyBytes() []byte {
	der := x509.MarshalPKCS1PrivateKey(s.privateKey)
	block := pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: der,
	}
	return pem.EncodeToMemory(&block)
}
// getPublicKeyBytes returns the public half of the key serialized by
// ssh.MarshalAuthorizedKey, i.e. in the form used for an authorized_keys entry.
func (s *sshManager) getPublicKeyBytes() []byte {
	rsaPub := &s.privateKey.PublicKey
	// The conversion error is deliberately discarded.
	// NOTE(review): presumably ssh.NewPublicKey cannot fail for an
	// *rsa.PublicKey - confirm against the x/crypto/ssh documentation.
	sshPub, _ := ssh.NewPublicKey(rsaPub)
	return ssh.MarshalAuthorizedKey(sshPub)
}
// getPublicKeyString returns the output of getPublicKeyBytes as a string with
// any leading and trailing newline characters stripped.
func (s *sshManager) getPublicKeyString() string {
	entry := string(s.getPublicKeyBytes())
	return strings.Trim(entry, "\n")
}
// insertAuthorizedKey appends this manager's public key to the file named by
// getAuthorizedKeysFile. The file is opened in append mode and is created with
// mode 0600 when it does not exist yet.
//
// It returns a wrapped error when the file cannot be opened, written or
// closed.
func (s *sshManager) insertAuthorizedKey() error {
	f, err := os.OpenFile(s.getAuthorizedKeysFile(), os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0600)
	if err != nil {
		return errors.Wrap(err, "Failed to open authorized_keys file")
	}
	if _, err := f.Write(s.getPublicKeyBytes()); err != nil {
		// The write error is the one worth reporting; a close error here
		// would only mask it.
		_ = f.Close()
		return errors.Wrap(err, "Failed to write authorized_keys file")
	}
	// Close is checked explicitly because a failure at this point means the
	// appended key may not have been stored.
	if err := f.Close(); err != nil {
		return errors.Wrap(err, "Failed to close authorized_keys file")
	}
	return nil
}
// removeAuthorizedKey deletes every line of the authorized_keys file that is
// exactly equal to this manager's public key and writes the remaining lines
// back. Lines that do not match are preserved as they are, and the file is not
// rewritten at all when no line matches.
//
// It returns a wrapped error when the file cannot be read or written.
func (s *sshManager) removeAuthorizedKey() error {
	path := s.getAuthorizedKeysFile()
	data, err := ioutil.ReadFile(path)
	if err != nil {
		// Stop here: continuing would rewrite the file from empty data.
		return errors.Wrap(err, "Failed to read authorized_keys file")
	}

	// Serialize the key once instead of on every loop iteration.
	key := s.getPublicKeyString()
	lines := strings.Split(string(data), "\n")
	kept := make([]string, 0, len(lines))
	for _, line := range lines {
		if line != key {
			kept = append(kept, line)
		}
	}
	if len(kept) == len(lines) {
		// Our key is not present; leave the file alone.
		return nil
	}

	// 0600 matches the mode insertAuthorizedKey creates the file with. The
	// mode only takes effect if WriteFile has to create the file.
	if err := ioutil.WriteFile(path, []byte(strings.Join(kept, "\n")), 0600); err != nil {
		return errors.Wrap(err, "Failed to write authorized_keys file")
	}
	log.Println("Removed entry to authorized_keys")
	return nil
}
// getAuthorizedKeysFile returns the path of the authorized_keys file inside
// the current user's home directory.
//
// The home directory is taken from user.Current. When that lookup fails or
// reports an empty home directory, the HOME environment variable is used
// instead, so a failed lookup no longer dereferences a nil user. If HOME is
// empty as well, the result is "/.ssh/authorized_keys".
func (s *sshManager) getAuthorizedKeysFile() string {
	home := os.Getenv("HOME")
	if usr, err := user.Current(); err == nil && usr.HomeDir != "" {
		home = usr.HomeDir
	}
	return home + "/.ssh/authorized_keys"
}
// NewSSHManager generates an ssh key and adds it to authorized_keys so Lambda
// can connect to the host.
//
// It creates a 2048-bit RSA key, appends its public half to the current user's
// authorized_keys file and returns the PEM-encoded private key. It also starts
// a goroutine that waits for os.Interrupt, removes the key from
// authorized_keys again and exits the process: with status 0 when the cleanup
// succeeded, with status 1 when it failed.
//
// An error is returned when the key cannot be generated or the
// authorized_keys entry cannot be added.
func NewSSHManager() ([]byte, error) {
	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return nil, errors.Wrap(err, "Error generating private SSH key")
	}
	s := &sshManager{
		privateKey: privateKey,
	}
	log.Println("Generated SSH key: ", s.getPublicKeyString())

	if err := s.insertAuthorizedKey(); err != nil {
		return nil, errors.Wrap(err, "Error adding authorized key")
	}
	log.Println("Added entry to authorized keys file ", s.getAuthorizedKeysFile())

	c := make(chan os.Signal, 1)
	signal.Notify(c, os.Interrupt)
	go func() {
		// The process exits after the first signal, so a single receive is
		// all that is needed.
		sig := <-c
		log.Println("Shutting down due to ", sig.String())
		log.Println("Cleaning up authorized_key file ", s.getAuthorizedKeysFile())
		if err := s.removeAuthorizedKey(); err != nil {
			// Report the failure so a key left behind does not go unnoticed.
			log.Println("Failed to clean up authorized_keys file: ", err)
			os.Exit(1)
		}
		os.Exit(0)
	}()
	return s.getPrivateKeyBytes(), nil
}