Skip to content

Commit

Permalink
[v11] Reduce time spent setting ssh session envs (#23833)
Browse files Browse the repository at this point in the history
* Reduce time spent setting ssh session envs

`tsh` sets a number of environment variables when setting up the
users session. Each key value pair is transmitted one at a time
in a "env" ssh request, which adds a num envs * RTT of additional
latency per session.

This introduces a new `envs@goteleport.com` request which sets
multiple environment variables in a single ssh request, which
reduces the amount of time spent setting envs down to the RTT of
a single ssh request. In order to ensure backward compat and
interoperability with OpenSSH, if the server does not recognize
the `envs@goteleport.com` request the ssh client will resort to
sending individual "env" requests.

* address feedback

* fix: use a single timer for fallback requests in tests

Co-authored-by: Alan Parra <alan.parra@goteleport.com>

* fix: remove extra whitespace

Co-authored-by: Alan Parra <alan.parra@goteleport.com>

* fix: gci

---------

Co-authored-by: Alan Parra <alan.parra@goteleport.com>
  • Loading branch information
rosstimothy and codingllama committed Mar 30, 2023
1 parent e150a97 commit 42bf611
Show file tree
Hide file tree
Showing 8 changed files with 342 additions and 32 deletions.
154 changes: 154 additions & 0 deletions api/observability/tracing/ssh/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ package ssh

import (
"context"
"encoding/json"
"fmt"
"testing"
"time"

"github.com/gravitational/trace"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -262,3 +264,155 @@ func TestNewSession(t *testing.T) {
})
}
}

// envReqParams are parameters for env request
type envReqParams struct {
Name string
Value string
}

// TestSetEnvs verifies that client uses EnvsRequest to
// send multiple envs and falls back to sending individual "env"
// requests if the server does not support EnvsRequests.
func TestSetEnvs(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
errChan := make(chan error, 5)

expected := map[string]string{"a": "1", "b": "2", "c": "3"}

// used to collect individual envs requests
envReqC := make(chan envReqParams, 3)

srv := newServer(t, func(conn *ssh.ServerConn, channels <-chan ssh.NewChannel, requests <-chan *ssh.Request) {
for {
select {
case <-ctx.Done():
return
case ch := <-channels:
switch {
case ch == nil:
return
case ch.ChannelType() == "session":
ch, reqs, err := ch.Accept()
if err != nil {
errChan <- trace.Wrap(err, "failed to accept session channel")
return
}

go func() {
defer ch.Close()
for i := 0; ; i++ {
select {
case <-ctx.Done():
return
case req := <-reqs:
if req == nil {
return
}

switch {
case i == 0 && req.Type == EnvsRequest: // accept 1st EnvsRequest
var envReq EnvsReq
if err := ssh.Unmarshal(req.Payload, &envReq); err != nil {
_ = req.Reply(false, []byte(err.Error()))
return
}

var envs map[string]string
if err := json.Unmarshal(envReq.EnvsJSON, &envs); err != nil {
_ = req.Reply(false, []byte(err.Error()))
return
}

for k, v := range expected {
actual, ok := envs[k]
if !ok {
_ = req.Reply(false, []byte(fmt.Sprintf("expected env %s not present", k)))
return
}

if actual != v {
_ = req.Reply(false, []byte(fmt.Sprintf("expected value %s for env %s, got %s", v, k, actual)))
return
}
}

_ = req.Reply(true, nil)
case i == 1 && req.Type == EnvsRequest: // reject additional EnvsRequest so we test fallbacks
_ = req.Reply(false, nil)
case i >= 2 && i <= len(expected)+2 && req.Type == "env": // accept individual "env" fallbacks.
var e envReqParams
if err := ssh.Unmarshal(req.Payload, &e); err != nil {
_ = req.Reply(false, []byte(err.Error()))
return
}
envReqC <- e
_ = req.Reply(true, nil)
default: // out of order or unexpected message
_ = req.Reply(false, []byte(fmt.Sprintf("unexpected ssh request %s on iteration %d", req.Type, i)))
errChan <- err
return
}
}
}
}()
default:
if err := ch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unexpected channel %s", ch.ChannelType())); err != nil {
errChan <- err
return
}
}
}
}
})

go srv.Run(errChan)

// create a client and open a session
conn, chans, reqs := srv.GetClient(t)
client := NewClient(conn, chans, reqs)
session, err := client.NewSession(ctx)
require.NoError(t, err)

// the first request shouldn't fall back
t.Run("envs set via envs@goteleport.com", func(t *testing.T) {
require.NoError(t, session.SetEnvs(ctx, expected))

select {
case <-envReqC:
t.Fatal("env request received instead of an envs@goteleport.com request")
default:
}
})

// subsequent requests should fall back to standard "env" requests
t.Run("envs set individually", func(t *testing.T) {
require.NoError(t, session.SetEnvs(ctx, expected))

envs := map[string]string{}
envsTimeout := time.NewTimer(3 * time.Second)
defer envsTimeout.Stop()
for i := 0; i < len(expected); i++ {
select {
case env := <-envReqC:
envs[env.Name] = env.Value
case <-envsTimeout.C:
t.Fatalf("Time out waiting for env request %d to be processed", i)
}
}

for k, v := range expected {
actual, ok := envs[k]
require.True(t, ok, "expected env %s to be set", k)
require.Equal(t, v, actual, "expected value %s for env %s, got %s", v, k, actual)
}
})

select {
case err := <-errChan:
require.NoError(t, err)
default:
}
}
87 changes: 76 additions & 11 deletions api/observability/tracing/ssh/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@ package ssh

import (
"context"
"encoding/json"
"fmt"
"strings"

"github.com/gravitational/trace"
"go.opentelemetry.io/otel/attribute"
semconv "go.opentelemetry.io/otel/semconv/v1.10.0"
oteltrace "go.opentelemetry.io/otel/trace"
Expand Down Expand Up @@ -51,7 +54,8 @@ func (s *Session) SendRequest(ctx context.Context, name string, wantReply bool,

// no need to wrap payload here, the session's channel wrapper will do it for us
s.wrapper.addContext(ctx, name)
return s.Session.SendRequest(name, wantReply, payload)
ok, err := s.Session.SendRequest(name, wantReply, payload)
return ok, trace.Wrap(err)
}

// Setenv sets an environment variable that will be applied to any
Expand All @@ -72,7 +76,66 @@ func (s *Session) Setenv(ctx context.Context, name, value string) error {
defer span.End()

s.wrapper.addContext(ctx, request)
return s.Session.Setenv(name, value)
return trace.Wrap(s.Session.Setenv(name, value))
}

// SetEnvs sets environment variables that will be applied to any
// command executed by Shell or Run. If the server does not handle
// [EnvsRequest] requests then the client falls back to sending individual
// "env" requests until all provided environment variables have been set
// or an error was received.
func (s *Session) SetEnvs(ctx context.Context, envs map[string]string) error {
config := tracing.NewConfig(s.wrapper.opts)
ctx, span := config.TracerProvider.Tracer(instrumentationName).Start(
ctx,
"ssh.SetEnvs",
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
oteltrace.WithAttributes(
semconv.RPCServiceKey.String("ssh.Session"),
semconv.RPCMethodKey.String("SendRequest"),
semconv.RPCSystemKey.String("ssh"),
),
)
defer span.End()

if len(envs) == 0 {
return nil
}

// If the server isn't Teleport fallback to individual "env" requests
if !strings.HasPrefix(string(s.wrapper.ServerVersion()), "SSH-2.0-Teleport") {
return trace.Wrap(s.setEnvFallback(ctx, envs))
}

raw, err := json.Marshal(envs)
if err != nil {
return trace.Wrap(err)
}

s.wrapper.addContext(ctx, EnvsRequest)
ok, err := s.Session.SendRequest(EnvsRequest, true, ssh.Marshal(EnvsReq{EnvsJSON: raw}))
if err != nil {
return trace.Wrap(err)
}

// The server does not handle EnvsRequest requests so fall back
// to sending individual requests.
if !ok {
return trace.Wrap(s.setEnvFallback(ctx, envs))
}

return nil
}

// setEnvFallback sends an "env" request for each item in envs.
func (s *Session) setEnvFallback(ctx context.Context, envs map[string]string) error {
for k, v := range envs {
if err := s.Setenv(ctx, k, v); err != nil {
return trace.Wrap(err, "failed to set environment variable %s", k)
}
}

return nil
}

// RequestPty requests the association of a pty with the session on the remote host.
Expand All @@ -95,7 +158,7 @@ func (s *Session) RequestPty(ctx context.Context, term string, h, w int, termmod
defer span.End()

s.wrapper.addContext(ctx, request)
return s.Session.RequestPty(term, h, w, termmodes)
return trace.Wrap(s.Session.RequestPty(term, h, w, termmodes))
}

// RequestSubsystem requests the association of a subsystem with the session on the remote host.
Expand All @@ -116,7 +179,7 @@ func (s *Session) RequestSubsystem(ctx context.Context, subsystem string) error
defer span.End()

s.wrapper.addContext(ctx, request)
return s.Session.RequestSubsystem(subsystem)
return trace.Wrap(s.Session.RequestSubsystem(subsystem))
}

// WindowChange informs the remote host about a terminal window dimension change to h rows and w columns.
Expand All @@ -138,7 +201,7 @@ func (s *Session) WindowChange(ctx context.Context, h, w int) error {
defer span.End()

s.wrapper.addContext(ctx, request)
return s.Session.WindowChange(h, w)
return trace.Wrap(s.Session.WindowChange(h, w))
}

// Signal sends the given signal to the remote process.
Expand All @@ -159,7 +222,7 @@ func (s *Session) Signal(ctx context.Context, sig ssh.Signal) error {
defer span.End()

s.wrapper.addContext(ctx, request)
return s.Session.Signal(sig)
return trace.Wrap(s.Session.Signal(sig))
}

// Start runs cmd on the remote host. Typically, the remote
Expand All @@ -181,7 +244,7 @@ func (s *Session) Start(ctx context.Context, cmd string) error {
defer span.End()

s.wrapper.addContext(ctx, request)
return s.Session.Start(cmd)
return trace.Wrap(s.Session.Start(cmd))
}

// Shell starts a login shell on the remote host. A Session only
Expand All @@ -202,7 +265,7 @@ func (s *Session) Shell(ctx context.Context) error {
defer span.End()

s.wrapper.addContext(ctx, request)
return s.Session.Shell()
return trace.Wrap(s.Session.Shell())
}

// Run runs cmd on the remote host. Typically, the remote
Expand Down Expand Up @@ -234,7 +297,7 @@ func (s *Session) Run(ctx context.Context, cmd string) error {
defer span.End()

s.wrapper.addContext(ctx, request)
return s.Session.Run(cmd)
return trace.Wrap(s.Session.Run(cmd))
}

// Output runs cmd on the remote host and returns its standard output.
Expand All @@ -254,7 +317,8 @@ func (s *Session) Output(ctx context.Context, cmd string) ([]byte, error) {
defer span.End()

s.wrapper.addContext(ctx, request)
return s.Session.Output(cmd)
output, err := s.Session.Output(cmd)
return output, trace.Wrap(err)
}

// CombinedOutput runs cmd on the remote host and returns its combined
Expand All @@ -275,5 +339,6 @@ func (s *Session) CombinedOutput(ctx context.Context, cmd string) ([]byte, error
defer span.End()

s.wrapper.addContext(ctx, request)
return s.Session.CombinedOutput(cmd)
output, err := s.Session.CombinedOutput(cmd)
return output, trace.Wrap(err)
}
13 changes: 13 additions & 0 deletions api/observability/tracing/ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ import (
)

const (
// EnvsRequest sets multiple environment variables that will be applied to any
// command executed by Shell or Run.
// See [EnvsReq] for the corresponding payload.
EnvsRequest = "envs@goteleport.com"

// TracingRequest is sent by clients to server to pass along tracing context.
TracingRequest = "tracing@goteleport.com"

Expand All @@ -44,6 +49,14 @@ const (
instrumentationName = "otelssh"
)

// EnvsReq contains json marshaled key:value pairs sent as the
// payload for an [EnvsRequest].
type EnvsReq struct {
// EnvsJSON is a json marshaled map[string]string containing
// environment variables.
EnvsJSON []byte `json:"envs"`
}

// ContextFromRequest extracts any tracing data provided via an Envelope
// in the ssh.Request payload. If the payload contains an Envelope, then
// the context returned will have tracing data populated from the remote
Expand Down
22 changes: 11 additions & 11 deletions lib/client/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,22 +225,22 @@ func (ns *NodeSession) createServerSession(ctx context.Context) (*tracessh.Sessi
return nil, trace.Wrap(err)
}

envs := map[string]string{}

// pass language info into the remote session.
evarsToPass := []string{"LANG", "LANGUAGE"}
for _, evar := range evarsToPass {
if value := os.Getenv(evar); value != "" {
err = sess.Setenv(ctx, evar, value)
if err != nil {
log.Warn(err)
}
langVars := []string{"LANG", "LANGUAGE"}
for _, env := range langVars {
if value := os.Getenv(env); value != "" {
envs[env] = value
}
}
// pass environment variables set by client
for key, val := range ns.env {
err = sess.Setenv(ctx, key, val)
if err != nil {
log.Warn(err)
}
envs[key] = val
}

if err := sess.SetEnvs(ctx, envs); err != nil {
log.Warn(err)
}

// if agent forwarding was requested (and we have a agent to forward),
Expand Down

0 comments on commit 42bf611

Please sign in to comment.