forked from jumpserver/koko
/
server.go
117 lines (101 loc) · 3.59 KB
/
server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
package sshd
import (
"context"
"net"
"strconv"
"time"
//"fmt"
"github.com/gliderlabs/ssh"
"github.com/pires/go-proxyproto"
gossh "golang.org/x/crypto/ssh"
"github.com/meowgen/koko/pkg/logger"
//"github.com/davecgh/go-spew/spew"
)
// SSH channel types that NewSSHServer registers channel handlers for.
const (
	sshChannelSession     = "session"      // handled by ssh.DefaultSessionHandler
	sshChannelDirectTCPIP = "direct-tcpip" // client-requested local port forwarding
)

// sshSubSystemSFTP is the subsystem name routed to SSHHandler.SFTPHandler.
const sshSubSystemSFTP = "sftp"
// Server wraps the underlying *ssh.Server (Srv). Start serves it on Srv.Addr
// through a PROXY-protocol-aware listener, and Stop shuts it down with a
// five-second timeout.
type Server struct {
// Srv is the configured gliderlabs SSH server; NewSSHServer builds it.
Srv *ssh.Server
}
// Start listens for TCP connections on s.Srv.Addr and serves SSH on them
// until the server is shut down. It blocks for the lifetime of the server.
//
// The listener is wrapped in a proxyproto.Listener so that a PROXY protocol
// header, when a client sends one, supplies the connection's remote address.
//
// Failing to bind the address is fatal, and so is any Serve error other than
// ssh.ErrServerClosed. Serve returns that sentinel after Shutdown/Close (that
// is, after Stop). It marks an orderly stop and must not abort the process
// while Stop is still running.
func (s *Server) Start() {
	logger.Infof("Start SSH server at %s", s.Srv.Addr)
	ln, err := net.Listen("tcp", s.Srv.Addr)
	if err != nil {
		logger.Fatal(err)
	}
	// NOTE(review): with no Policy set, proxyproto.Listener accepts a PROXY
	// header from any peer, so a client that reaches this port directly can
	// spoof its source address. Consider a Policy that only trusts known
	// upstream proxies.
	proxyListener := &proxyproto.Listener{Listener: ln}

	// Serve always returns a non-nil error. gliderlabs/ssh returns the
	// ErrServerClosed sentinel unwrapped, so a direct comparison identifies a
	// normal shutdown.
	err = s.Srv.Serve(proxyListener)
	if err != ssh.ErrServerClosed {
		logger.Fatal(err)
	}
	logger.Infof("SSH server at %s stopped", s.Srv.Addr)
}
// Stop gracefully shuts the server down, allowing Shutdown at most five
// seconds to complete.
//
// Only a failed or timed-out shutdown is reported through logger.Fatal. A
// clean shutdown (nil error) returns normally, so the caller can go on to stop
// other services.
func (s *Server) Stop() {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := s.Srv.Shutdown(ctx); err != nil {
		logger.Fatal(err)
	}
}
// SSHHandler supplies everything NewSSHServer delegates to:
//   - the listen address and host key,
//   - the authentication callbacks,
//   - the session and SFTP handlers,
//   - the permission check and handler for "direct-tcpip" (local port
//     forwarding) channels.
type SSHHandler interface {
	// GetSSHAddr returns the address used as the server's Addr.
	GetSSHAddr() string
	// GetSSHSigner returns the server's only host signer.
	GetSSHSigner() ssh.Signer

	// PasswordAuth decides a password authentication attempt.
	PasswordAuth(ctx ssh.Context, password string) AuthStatus
	// PublicKeyAuth decides a public-key authentication attempt.
	PublicKeyAuth(ctx ssh.Context, key ssh.PublicKey) AuthStatus
	// KeyboardInteractiveAuth decides a keyboard-interactive attempt, using
	// challenger to prompt the client.
	KeyboardInteractiveAuth(ctx ssh.Context, challenger gossh.KeyboardInteractiveChallenge) AuthStatus
	// NextAuthMethodsHandler backs ssh.Server.NextAuthMethodsHandler;
	// presumably it lists the methods to offer after a partial success — verify.
	NextAuthMethodsHandler(ctx ssh.Context) []string

	// SessionHandler is installed as the server's default session Handler.
	SessionHandler(sess ssh.Session)
	// SFTPHandler serves the "sftp" subsystem.
	SFTPHandler(sess ssh.Session)

	// LocalPortForwardingPermission reports whether a direct-tcpip channel to
	// destinationHost:destinationPort may be opened.
	LocalPortForwardingPermission(ctx ssh.Context, destinationHost string, destinationPort uint32) bool
	// DirectTCPIPChannelHandler takes over an approved direct-tcpip channel;
	// destAddr is the requested destination formatted as "host:port".
	DirectTCPIPChannelHandler(ctx ssh.Context, newChan gossh.NewChannel, destAddr string)
}
// AuthStatus is the authentication outcome returned by SSHHandler's auth
// methods. It has the same underlying type as ssh.AuthResult, and
// NewSSHServer converts it straight back to that type.
type AuthStatus ssh.AuthResult

// Authentication outcomes; each equals the ssh.AuthResult of the same name.
const (
	// AuthSuccessful accepts the attempt.
	AuthSuccessful = AuthStatus(ssh.AuthSuccessful)
	// AuthPartiallySuccessful presumably means this method passed but further
	// methods are required — confirm against the ssh fork in use.
	AuthPartiallySuccessful = AuthStatus(ssh.AuthPartiallySuccessful)
	// AuthFailed rejects the attempt.
	AuthFailed = AuthStatus(ssh.AuthFailed)
)
// NewSSHServer builds a Server whose behaviour is delegated to handler.
//
// The returned server:
//   - listens on handler.GetSSHAddr();
//   - presents handler.GetSSHSigner() as its only host key;
//   - routes password, public-key and keyboard-interactive authentication to
//     the matching handler methods;
//   - uses handler.SessionHandler for sessions and handler.SFTPHandler for the
//     "sftp" subsystem;
//   - accepts only "session" and "direct-tcpip" channels.
//
// A direct-tcpip channel is rejected unless its extra data decodes and
// LocalPortForwardingPermission approves the destination. Approved channels
// are handed to handler.DirectTCPIPChannelHandler with a "host:port" address.
func NewSSHServer(handler SSHHandler) *Server {
	// The handler speaks AuthStatus; the ssh package expects ssh.AuthResult.
	passwordAuth := func(ctx ssh.Context, password string) ssh.AuthResult {
		return ssh.AuthResult(handler.PasswordAuth(ctx, password))
	}
	publicKeyAuth := func(ctx ssh.Context, key ssh.PublicKey) ssh.AuthResult {
		return ssh.AuthResult(handler.PublicKeyAuth(ctx, key))
	}
	keyboardInteractiveAuth := func(ctx ssh.Context, challenger gossh.KeyboardInteractiveChallenge) ssh.AuthResult {
		return ssh.AuthResult(handler.KeyboardInteractiveAuth(ctx, challenger))
	}

	forwardDirectTCPIP := func(srv *ssh.Server, _ *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
		var payload localForwardChannelData
		if err := gossh.Unmarshal(newChan.ExtraData(), &payload); err != nil {
			// A failed Reject only means the client is already gone.
			_ = newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error())
			return
		}
		allow := srv.LocalPortForwardingCallback
		if allow == nil || !allow(ctx, payload.DestAddr, payload.DestPort) {
			_ = newChan.Reject(gossh.Prohibited, "port forwarding is disabled")
			return
		}
		port := strconv.FormatUint(uint64(payload.DestPort), 10)
		handler.DirectTCPIPChannelHandler(ctx, newChan, net.JoinHostPort(payload.DestAddr, port))
	}

	srv := &ssh.Server{
		Addr:        handler.GetSSHAddr(),
		HostSigners: []ssh.Signer{handler.GetSSHSigner()},

		PasswordHandler:            passwordAuth,
		PublicKeyHandler:           publicKeyAuth,
		KeyboardInteractiveHandler: keyboardInteractiveAuth,
		NextAuthMethodsHandler:     handler.NextAuthMethodsHandler,

		LocalPortForwardingCallback: handler.LocalPortForwardingPermission,

		Handler: handler.SessionHandler,
		SubsystemHandlers: map[string]ssh.SubsystemHandler{
			sshSubSystemSFTP: handler.SFTPHandler,
		},
		ChannelHandlers: map[string]ssh.ChannelHandler{
			sshChannelSession:     ssh.DefaultSessionHandler,
			sshChannelDirectTCPIP: forwardDirectTCPIP,
		},
	}
	return &Server{Srv: srv}
}
// localForwardChannelData is the extra data of a "direct-tcpip" channel-open
// request. NewSSHServer decodes it with gossh.Unmarshal, which fills struct
// fields in declaration order. The field order must therefore match the wire
// layout: host to connect, port to connect, originator address, originator
// port (RFC 4254, section 7.2).
type localForwardChannelData struct {
DestAddr string
DestPort uint32
OriginAddr string
OriginPort uint32
}