Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v10] Improve web ui ssh performance #19119

Merged
merged 3 commits into from
Dec 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions api/utils/sshutils/callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package sshutils

import (
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
)
Expand All @@ -35,13 +36,18 @@ type HostKeyCallbackConfig struct {
FIPS bool
// OnCheckCert is called on SSH certificate validation.
OnCheckCert func(*ssh.Certificate)
// Clock is used to set the Checker Time
Clock clockwork.Clock
}

// Check validates the config.
func (c *HostKeyCallbackConfig) Check() error {
if c.GetHostCheckers == nil {
return trace.BadParameter("missing GetHostCheckers")
}
if c.Clock == nil {
c.Clock = clockwork.NewRealClock()
}
return nil
}

Expand All @@ -54,6 +60,7 @@ func NewHostKeyCallback(conf HostKeyCallbackConfig) (ssh.HostKeyCallback, error)
CertChecker: ssh.CertChecker{
IsHostAuthority: makeIsHostAuthorityFunc(conf.GetHostCheckers),
HostKeyFallback: conf.HostKeyFallback,
Clock: conf.Clock.Now,
},
FIPS: conf.FIPS,
OnCheckCert: conf.OnCheckCert,
Expand Down
116 changes: 79 additions & 37 deletions lib/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"os"
Expand Down Expand Up @@ -84,7 +85,6 @@ type NodeClient struct {
Namespace string
Tracer oteltrace.Tracer
Client *tracessh.Client
Proxy *ProxyClient
TC *TeleportClient
OnMFA func()
FIPSEnabled bool
Expand Down Expand Up @@ -1619,6 +1619,7 @@ func (proxy *ProxyClient) ConnectToNode(ctx context.Context, nodeAddress NodeDet
return nil, trace.ConnectionProblem(err, "failed connecting to node %v. %s",
nodeName(nodeAddress.Addr), serverErrorMsg)
}

pipeNetConn := utils.NewPipeNetConn(
proxyReader,
proxyWriter,
Expand All @@ -1632,35 +1633,9 @@ func (proxy *ProxyClient) ConnectToNode(ctx context.Context, nodeAddress NodeDet
Auth: authMethods,
HostKeyCallback: proxy.hostKeyCallback,
}
conn, chans, reqs, err := newClientConn(ctx, pipeNetConn, nodeAddress.ProxyFormat(), sshConfig)
if err != nil {
if utils.IsHandshakeFailedError(err) {
proxySession.Close()
return nil, trace.AccessDenied(`access denied to %v connecting to %v`, user, nodeAddress)
}
return nil, trace.Wrap(err)
}

// We pass an empty channel which we close right away to ssh.NewClient
// because the client need to handle requests itself.
emptyCh := make(chan *ssh.Request)
close(emptyCh)

nc := &NodeClient{
Client: tracessh.NewClient(conn, chans, emptyCh),
Proxy: proxy,
Namespace: apidefaults.Namespace,
TC: proxy.teleportClient,
Tracer: proxy.Tracer,
FIPSEnabled: details.FIPSEnabled,
}

// Start a goroutine that will run for the duration of the client to process
// global requests from the client. Teleport clients will use this to update
// terminal sizes when the remote PTY size has changed.
go nc.handleGlobalRequests(ctx, reqs)

return nc, nil
nc, err := NewNodeClient(ctx, sshConfig, pipeNetConn, nodeAddress.ProxyFormat(), proxy.teleportClient, details.FIPSEnabled)
return nc, trace.Wrap(err)
}

// PortForwardToNode connects to the ssh server via Proxy
Expand Down Expand Up @@ -1704,11 +1679,28 @@ func (proxy *ProxyClient) PortForwardToNode(ctx context.Context, nodeAddress Nod
Auth: authMethods,
HostKeyCallback: proxy.hostKeyCallback,
}
conn, chans, reqs, err := newClientConn(ctx, proxyConn, nodeAddress.Addr, sshConfig)

nc, err := NewNodeClient(ctx, sshConfig, proxyConn, nodeAddress.Addr, proxy.teleportClient, details.FIPSEnabled)
return nc, trace.Wrap(err)
}

// NewNodeClient constructs a NodeClient that is connected to the node at nodeAddress
func NewNodeClient(ctx context.Context, sshConfig *ssh.ClientConfig, conn net.Conn, nodeAddress string, tc *TeleportClient, fipsEnabled bool) (*NodeClient, error) {
ctx, span := tc.Tracer.Start(
ctx,
"NewNodeClient",
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
oteltrace.WithAttributes(
attribute.String("node", nodeAddress),
),
)
defer span.End()

sshconn, chans, reqs, err := newClientConn(ctx, conn, nodeAddress, sshConfig)
if err != nil {
if utils.IsHandshakeFailedError(err) {
proxyConn.Close()
return nil, trace.AccessDenied(`access denied to %v connecting to %v`, user, nodeAddress)
conn.Close()
return nil, trace.AccessDenied(`access denied to %v connecting to %v`, sshConfig.User, nodeAddress)
}
return nil, trace.Wrap(err)
}
Expand All @@ -1719,11 +1711,11 @@ func (proxy *ProxyClient) PortForwardToNode(ctx context.Context, nodeAddress Nod
close(emptyCh)

nc := &NodeClient{
Client: tracessh.NewClient(conn, chans, emptyCh),
Proxy: proxy,
Namespace: apidefaults.Namespace,
TC: proxy.teleportClient,
Tracer: proxy.Tracer,
Client: tracessh.NewClient(sshconn, chans, emptyCh),
Namespace: apidefaults.Namespace,
TC: tc,
Tracer: tc.Tracer,
FIPSEnabled: fipsEnabled,
}

// Start a goroutine that will run for the duration of the client to process
Expand All @@ -1734,6 +1726,56 @@ func (proxy *ProxyClient) PortForwardToNode(ctx context.Context, nodeAddress Nod
return nc, nil
}

// RunInteractiveShell creates an interactive shell on the node and copies stdin/stdout/stderr
// to and from the node and local shell. This will block until the interactive shell on the node
// is terminated.
func (c *NodeClient) RunInteractiveShell(ctx context.Context, mode types.SessionParticipantMode, sessToJoin types.SessionTracker) error {
ctx, span := c.Tracer.Start(
ctx,
"nodeClient/RunInteractiveShell",
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
)
defer span.End()

env := make(map[string]string)
env[teleport.EnvSSHJoinMode] = string(mode)
env[teleport.EnvSSHSessionReason] = c.TC.Config.Reason
env[teleport.EnvSSHSessionDisplayParticipantRequirements] = strconv.FormatBool(c.TC.Config.DisplayParticipantRequirements)
encoded, err := json.Marshal(&c.TC.Config.Invited)
if err != nil {
return trace.Wrap(err)
}

env[teleport.EnvSSHSessionInvited] = string(encoded)
for key, value := range c.TC.Env {
env[key] = value
}

nodeSession, err := newSession(ctx, c, sessToJoin, env, c.TC.Stdin, c.TC.Stdout, c.TC.Stderr, c.TC.EnableEscapeSequences)
if err != nil {
return trace.Wrap(err)
}

if err = nodeSession.runShell(ctx, mode, nil, c.TC.OnShellCreated); err != nil {
switch e := trace.Unwrap(err).(type) {
case *ssh.ExitError:
c.TC.ExitStatus = e.ExitStatus()
case *ssh.ExitMissingError:
c.TC.ExitStatus = 1
}

return trace.Wrap(err)
}

if nodeSession.ExitMsg == "" {
fmt.Fprintln(c.TC.Stderr, "the connection was closed on the remote side at ", time.Now().Format(time.RFC822))
} else {
fmt.Fprintln(c.TC.Stderr, nodeSession.ExitMsg)
}

return nil
}

func (c *NodeClient) handleGlobalRequests(ctx context.Context, requestCh <-chan *ssh.Request) {
for {
select {
Expand Down
2 changes: 1 addition & 1 deletion lib/client/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ func (ns *NodeSession) createServerSession(ctx context.Context) (*tracessh.Sessi

// if agent forwarding was requested (and we have a agent to forward),
// forward the agent to endpoint.
tc := ns.nodeClient.Proxy.teleportClient
tc := ns.nodeClient.TC
targetAgent := selectKeyAgent(tc)

if targetAgent != nil {
Expand Down
1 change: 1 addition & 0 deletions lib/reversetunnel/localsite.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ func (s *localSite) dialWithAgent(params DialParams) (net.Conn, error) {
TargetID: params.ServerID,
TargetAddr: params.To.String(),
TargetHostname: params.Address,
Clock: s.clock,
}
remoteServer, err := forward.New(serverConfig)
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions lib/reversetunnel/remotesite.go
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,7 @@ func (s *remoteSite) dialWithAgent(params DialParams) (net.Conn, error) {
TargetID: params.ServerID,
TargetAddr: params.To.String(),
TargetHostname: params.Address,
Clock: s.clock,
}
remoteServer, err := forward.New(serverConfig)
if err != nil {
Expand Down
3 changes: 3 additions & 0 deletions lib/reversetunnel/srv.go
Original file line number Diff line number Diff line change
Expand Up @@ -875,6 +875,9 @@ func (s *server) checkClientCert(logger *log.Entry, user string, clusterName str
}

checker := apisshutils.CertChecker{
CertChecker: ssh.CertChecker{
Clock: s.Clock.Now,
},
FIPS: s.FIPS,
}
if err := checker.CheckCert(user, cert); err != nil {
Expand Down
12 changes: 7 additions & 5 deletions lib/reversetunnel/srv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,17 @@ import (
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"

"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/auth/testauthority"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/utils"

"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
)

func TestServerKeyAuth(t *testing.T) {
Expand All @@ -56,7 +57,8 @@ func TestServerKeyAuth(t *testing.T) {
require.NoError(t, err)

s := &server{
log: utils.NewLoggerForTests(),
log: utils.NewLoggerForTests(),
Config: Config{Clock: clockwork.NewRealClock()},
localAccessPoint: mockAccessPoint{
ca: ca,
},
Expand Down
73 changes: 38 additions & 35 deletions lib/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -3470,6 +3470,42 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
}
}

var proxyRouter *proxy.Router
if !process.Config.Proxy.DisableReverseTunnel {
router, err := proxy.NewRouter(proxy.RouterConfig{
ClusterName: clusterName,
Log: process.log.WithField(trace.Component, "router"),
RemoteClusterGetter: accessPoint,
SiteGetter: tsrv,
TracerProvider: process.TracingProvider,
})
if err != nil {
return trace.Wrap(err)
}

proxyRouter = router
}

// read the host UUID:
serverID, err := utils.ReadOrMakeHostUUID(cfg.DataDir)
if err != nil {
return trace.Wrap(err)
}

sessionController, err := srv.NewSessionController(srv.SessionControllerConfig{
Semaphores: accessPoint,
AccessPoint: accessPoint,
LockEnforcer: lockWatcher,
Emitter: &events.StreamerAndEmitter{Emitter: asyncEmitter, Streamer: streamer},
Component: teleport.ComponentProxy,
Logger: process.log.WithField(trace.Component, "sessionctrl"),
TracerProvider: process.TracingProvider,
ServerID: serverID,
})
if err != nil {
return trace.Wrap(err)
}

// Register web proxy server
alpnHandlerForWeb := &alpnproxy.ConnectionHandlerWrapper{}
var webServer *http.Server
Expand Down Expand Up @@ -3518,6 +3554,8 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
ALPNHandler: alpnHandlerForWeb.HandleConnection,
PublicProxyAddr: process.proxyPublicAddr().Addr,
ProxyKubeAddr: proxyKubeAddr,
Router: proxyRouter,
SessionControl: sessionController,
}

webHandler, err = web.NewHandler(webConfig)
Expand Down Expand Up @@ -3612,41 +3650,6 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
})
}

var proxyRouter *proxy.Router
if !process.Config.Proxy.DisableReverseTunnel {
router, err := proxy.NewRouter(proxy.RouterConfig{
ClusterName: clusterName,
Log: process.log.WithField(trace.Component, "router"),
RemoteClusterGetter: accessPoint,
SiteGetter: tsrv,
TracerProvider: process.TracingProvider,
})
if err != nil {
return trace.Wrap(err)
}
proxyRouter = router
}

// read the host UUID:
serverID, err := utils.ReadOrMakeHostUUID(cfg.DataDir)
if err != nil {
return trace.Wrap(err)
}

sessionController, err := srv.NewSessionController(srv.SessionControllerConfig{
Semaphores: accessPoint,
AccessPoint: accessPoint,
LockEnforcer: lockWatcher,
Emitter: &events.StreamerAndEmitter{Emitter: asyncEmitter, Streamer: streamer},
Component: teleport.ComponentProxy,
Logger: process.log.WithField(trace.Component, "sessionctrl"),
TracerProvider: process.TracingProvider,
ServerID: serverID,
})
if err != nil {
return trace.Wrap(err)
}

sshProxy, err := regular.New(
process.ExitContext(),
cfg.SSH.Addr,
Expand Down