/
server.go
113 lines (99 loc) · 3.47 KB
/
server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
package sshd
import (
	"context"
	"errors"
	"net"
	"strconv"
	"time"

	"github.com/gliderlabs/ssh"
	"github.com/pires/go-proxyproto"
	gossh "golang.org/x/crypto/ssh"

	"github.com/jumpserver/koko/pkg/auth"
	"github.com/jumpserver/koko/pkg/config"
	"github.com/jumpserver/koko/pkg/handler"
	"github.com/jumpserver/koko/pkg/jms-sdk-go/service"
	"github.com/jumpserver/koko/pkg/logger"
)
// Channel types and subsystem names that the server routes to handlers.
const (
	// sshChannelSession is the channel type for interactive sessions.
	sshChannelSession = "session"
	// sshChannelDirectTCPIP is the channel type for local port forwarding
	// requests; its handler consults LocalPortForwardingCallback.
	sshChannelDirectTCPIP = "direct-tcpip"
	// sshSubSystemSFTP is the subsystem name dispatched to the SFTP handler.
	sshSubSystemSFTP = "sftp"
)

// supportedMACs is the MAC allow-list placed in gossh.Config.MACs by the
// server's ServerConfigCallback.
var supportedMACs = []string{
	"hmac-sha2-256-etm@openssh.com",
	"hmac-sha2-256",
	"hmac-sha1",
}

// supportedKexAlgos is the key-exchange allow-list placed in
// gossh.Config.KeyExchanges by the server's ServerConfigCallback.
var supportedKexAlgos = []string{
	"curve25519-sha256",
	"curve25519-sha256@libssh.org",
	"ecdh-sha2-nistp256",
	"ecdh-sha2-nistp384",
	"ecdh-sha2-nistp521",
}
// Server bundles the configured SSH server with the koko handler that
// supplies its authentication callbacks, session handler, SFTP subsystem
// and direct-tcpip channel handling.
//
// NOTE: NewSSHServer builds this with a positional composite literal
// (&Server{srv, sshHandler}), so the field order below is significant.
type Server struct {
// Srv is the underlying SSH server; Start serves it and Stop shuts it down.
Srv *ssh.Server
// Handler is the handler.Server whose methods are wired into Srv.
Handler *handler.Server
}
// Start listens on s.Srv.Addr over TCP, wraps the listener in a go-proxyproto
// Listener, and serves SSH connections on it. It blocks until Serve returns.
//
// A listen failure, or any Serve error other than ssh.ErrServerClosed, is
// passed to logger.Fatal. Serve returns ssh.ErrServerClosed after Stop
// (Shutdown) has been called. That case is logged as a normal stop so that a
// graceful shutdown is not reported as a fatal error from this goroutine
// (logger.Fatal presumably exits the process; confirm in pkg/logger).
func (s *Server) Start() {
	logger.Infof("Start SSH server at %s", s.Srv.Addr)
	ln, err := net.Listen("tcp", s.Srv.Addr)
	if err != nil {
		logger.Fatal(err)
	}
	proxyListener := &proxyproto.Listener{Listener: ln}
	err = s.Srv.Serve(proxyListener)
	if err != nil && !errors.Is(err, ssh.ErrServerClosed) {
		logger.Fatal(err)
	}
	logger.Infof("SSH server at %s stopped", s.Srv.Addr)
}
// Stop asks the SSH server to shut down and gives Shutdown at most 5 seconds
// (the context deadline) to finish.
//
// The result is logged rather than handed to logger.Fatal. Shutdown returns
// nil on success, and passing even a nil error to a Fatal logger would abort
// the process (assuming logger.Fatal exits, as Fatal loggers conventionally
// do). That would skip whatever the caller does after Stop returns.
func (s *Server) Stop() {
	ctx, cancelFunc := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancelFunc()
	if err := s.Srv.Shutdown(ctx); err != nil {
		logger.Errorf("SSH server shutdown failed: %s", err)
		return
	}
	logger.Infof("SSH server at %s shut down", s.Srv.Addr)
}
// nextAuthMethod is the single method name returned by the server's
// NextAuthMethodsHandler: the SSH keyboard-interactive method (RFC 4256).
const (
	nextAuthMethod = "keyboard-interactive"
)
// NewSSHServer builds the SSH server from config.GlobalConfig and the
// terminal configuration fetched through jmsService.
//
// The server listens on BindHost:SSHPort and uses the terminal config's
// HostKey as its host key. Password and public-key auth go to the
// handler.Server; keyboard-interactive auth and auth logging go to the auth
// package. Sessions, the SFTP subsystem and direct-tcpip (local port
// forwarding) channels are routed to the handler. Failing to fetch the
// terminal config or to parse the host key is fatal (logger.Fatal/Fatalf).
func NewSSHServer(jmsService *service.JMService) *Server {
	globalCfg := config.GlobalConfig
	listenAddr := net.JoinHostPort(globalCfg.BindHost, globalCfg.SSHPort)

	terminalCfg, err := jmsService.GetTerminalConfig()
	if err != nil {
		logger.Fatal(err)
	}
	hostSigner, err := ParsePrivateKeyFromString(terminalCfg.HostKey)
	if err != nil {
		logger.Fatalf("Parse Terminal private key failed: %s\n", err)
	}
	sshHandler := handler.NewServer(terminalCfg, jmsService)

	// Builds a fresh ServerConfig per connection that restricts the offered
	// MACs and key exchanges to the package allow-lists.
	serverCfgCallback := func(ctx ssh.Context) *gossh.ServerConfig {
		return &gossh.ServerConfig{
			Config: gossh.Config{MACs: supportedMACs, KeyExchanges: supportedKexAlgos},
		}
	}

	// Decodes the direct-tcpip open payload (see localForwardChannelData),
	// rejects the channel unless LocalPortForwardingCallback allows the
	// destination, then passes host:port on to the handler.
	directTCPIP := func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
		var payload localForwardChannelData
		if parseErr := gossh.Unmarshal(newChan.ExtraData(), &payload); parseErr != nil {
			_ = newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+parseErr.Error())
			return
		}
		allowed := srv.LocalPortForwardingCallback != nil &&
			srv.LocalPortForwardingCallback(ctx, payload.DestAddr, payload.DestPort)
		if !allowed {
			_ = newChan.Reject(gossh.Prohibited, "port forwarding is disabled")
			return
		}
		target := net.JoinHostPort(payload.DestAddr, strconv.FormatInt(int64(payload.DestPort), 10))
		sshHandler.DirectTCPIPChannelHandler(ctx, newChan, target)
	}

	srv := &ssh.Server{
		Addr:        listenAddr,
		HostSigners: []ssh.Signer{hostSigner},

		// Authentication.
		PasswordHandler:            sshHandler.PasswordAuth,
		PublicKeyHandler:           sshHandler.PublicKeyAuth,
		KeyboardInteractiveHandler: auth.SSHKeyboardInteractiveAuth,
		NextAuthMethodsHandler:     func(ssh.Context) []string { return []string{nextAuthMethod} },
		AuthLogCallback:            auth.SSHAuthLogCallback,

		ServerConfigCallback: serverCfgCallback,

		// Sessions, subsystems and channels.
		Handler:                     sshHandler.SessionHandler,
		SubsystemHandlers:           map[string]ssh.SubsystemHandler{sshSubSystemSFTP: sshHandler.SFTPHandler},
		LocalPortForwardingCallback: sshHandler.LocalPortForwardingPermission,
		ChannelHandlers: map[string]ssh.ChannelHandler{
			sshChannelSession:     ssh.DefaultSessionHandler,
			sshChannelDirectTCPIP: directTCPIP,
		},
	}
	return &Server{Srv: srv, Handler: sshHandler}
}
// localForwardChannelData is the extra data of a "direct-tcpip" channel open
// request, decoded with gossh.Unmarshal by the direct-tcpip channel handler.
//
// gossh.Unmarshal reads SSH wire-format fields in struct declaration order, so
// the order below must match RFC 4254 §7.2: host to connect, port to connect,
// originator address, originator port. Do not reorder.
type localForwardChannelData struct {
// DestAddr is the host the client asked to connect to.
DestAddr string
// DestPort is the port the client asked to connect to.
DestPort uint32
// OriginAddr is the originator address reported by the client.
OriginAddr string
// OriginPort is the originator port reported by the client.
OriginPort uint32
}